Compare commits

...

81 Commits

Author SHA1 Message Date
Chiggy-Playz
96153bb177
Fix slash command scope image in docs (#104)
This time its visible properly
2021-10-30 17:18:11 +01:00
Gnome
babbb22462
Fix small typing issue 2021-10-29 14:06:44 +01:00
Gnome
eef8c07379
Optimise _unwrap_slash_groups and similar 2021-10-29 13:33:15 +01:00
Chiggy-Playz
be9e693047
Fix Literal inside Optional not showing choices (#98) 2021-10-27 14:00:21 +01:00
Chiggy-Playz
351bc5bc19
Add Protocol Urls (#103)
Co-authored-by: Stocker <44980366+StockerMC@users.noreply.github.com>
Co-authored-by: Gnome! <45660393+Gnome-py@users.noreply.github.com>
2021-10-27 13:32:50 +01:00
Gnome
5bb88062fa
Basic interaction autocomplete support 2021-10-26 12:27:31 +01:00
Gnome
e99ee71233
Add ctx.defer to help with 3 second slash command response rule.
Acts as `ctx.interaction.response.defer` or loops `ctx.trigger_typing` depending on context.
2021-10-23 21:19:51 +01:00
iDutchy
63dbecf65d
Fix incorrect doc
I forgot about the decorator... min_values can also be 0, so this should prevent confusion
2021-10-20 02:04:28 +02:00
iDutchy
f46d3bfa28
fix incorrect doc
As it seems, this stated min_values must be between 1-25 even tho docs state it must be between 0-25. This changes that doc so that it might prevent confusion in the future
2021-10-20 01:54:11 +02:00
Stocker
983cbb3161
Add the ability to set the option name with commands.Option (#102)
* Add the ability to set the option name with commands.Option
* Document commands.Option.name
2021-10-16 15:00:56 +01:00
Soheab
838d9d8986
Add ability to set a flag description. (#99)
* Add ability to set a flag description.

This PR adds the ability to set a flag description that shows in the slash command options menu.
2021-10-16 13:27:02 +01:00
Chiggy-Playz
e0bf2f9121
Add Channel types support (#100) 2021-10-13 17:34:13 +01:00
Gnome
0abac8698d
Fix slash command flag parsing
Also removes the extra space at the end of fake message content
2021-10-08 20:06:05 +01:00
Gnome
d781af8be5
Remove maintainer list from README.rst
This list became outdated straight away, and is a bad idea in general.
2021-10-08 18:24:22 +01:00
Gnome
9e31aad96d Fix code style issues with Black 2021-10-07 17:34:29 +01:00
Chiggy-Playz
eca1d9a470
Sort events by categories (#88) 2021-10-07 16:48:38 +01:00
Duck
0bbcfd7f33
Update resource links (#65)
* Updated links

* Remove github discussions from getting help
2021-10-06 20:32:48 +01:00
Ian Webster
ec1e2add21
Update user-agent (#92) 2021-10-04 21:11:10 +01:00
Gnome
4277f65051 Implement _FakeSlashMessage.clean_content
Closes #83
2021-10-03 21:05:00 +01:00
Gnome!
3260ec6643
Add improved docs for slash commands (#77)
* Fix command checks actually working

* Current progress on slash command docs

* Improve docs for slash commands further
2021-09-27 01:14:07 -07:00
Chiggy-Playz
d16d2d856f
Sort subcommand names (#68) 2021-09-25 22:43:23 -07:00
Gnome!
456d71d228
Add better support for MENTIONABLE (#74) 2021-09-25 22:41:43 -07:00
Gnome!
093a38527d
Fix slash command prefix to / (#75) 2021-09-25 22:40:35 -07:00
NORXND
163d8e6586
Merge pull request #76
* Fix docs in BadInviteArgument class
2021-09-25 22:39:09 -07:00
Tom
0637a628ca
update workflows (#73)
* modify workflows to fit into one file, fix pyright workflow

* remove redundant pip install

* add check flag to black

* use psf/black for black checker
2021-09-21 14:51:46 -07:00
Gnome!
02c6b07834
Merge pull request #72
* Fix command checks actually working
2021-09-21 14:34:54 -07:00
Gnome!
b810848273
Merge pull request #70
* Fix embed image/thumbnail property
2021-09-21 12:10:16 -07:00
Astrea
cd4bb296f3
Merge pull request #58
* FIxed `userinfo` command not returning an avatar...

* Quick merge conflict fix

* Merge branch '2.0' into converter-example-fix

* Fix code style issues with Black
2021-09-21 11:52:55 -07:00
Arnav Jindal
6a63ce2ed7
Add typechecking for PRS/Commits (#59)
* Create ci.yml

* Create .python-black

* Remove linting
2021-09-21 11:52:03 -07:00
Gnome!
fba7ca420c
Merge pull request #63
* Add ephemeral attachment field

* I did not miss a comma
2021-09-21 11:51:23 -07:00
Gnome!
e65415d3c8
Merge pull request #60
* Rework how checks add attributes to Commmand

* Merge remote-tracking branch 'upstream/2.0' into command-attrs-checks
2021-09-21 11:47:28 -07:00
Astrea
2ecf755372
Merge pull request #57
* FIx _accent_colour being improperly typehinted
2021-09-21 11:37:28 -07:00
Gnome!
00ae8bb18c
Fix all invites to devision server invite (#69) 2021-09-20 21:25:48 +02:00
iDutchy
0638bda719
Fix docs invite
Invite link on docs was still set to dpy, this changes it to edpy
2021-09-19 02:42:43 +02:00
Gnome!
1957fa6011
Implement a least breaking approach to slash commands (#39)
* Most slash command support completed, needs some debugging (and reindent)

* Implement a ctx.send helper for slash commands

* Add group command support

* Add Option converter, fix default optional, fix help command

* Add client.setup and move readying commands to that

* Implement _FakeSlashMessage.from_interaction

* Rename normmal_command to message_command

* Add docs for added params

* Add slash_command_guilds to bot and decos

* Fix merge conflict

* Remove name from commands.Option, wasn't used

* Move slash command processing to BotBase.process_slash_commands

* Create slash_only.py

Basic example for slash commands

* Create slash_and_message.py

Basic example for mixed commands

* Fix slash_command and normal_command bools

* Add some basic error handling for registration

* Fixed converter upload errors

* Fix some logic and make an actual example

* Thanks Safety Jim

* docstrings, *args, and error changes

* Add proper literal support

* Add basic documentation on slash commands

* Fix non-slash command interactions

* Fix ctx.reply in slash command context

* Fix typing on Context.reply

* Fix multiple optional argument sorting

* Update ctx.message docs to mention error instead of warning

* Move slash command creation to BotBase

* Fix code style issues with Black

* Rearrange some stuff and add flag support

* Change some errors and fix interaction.channel fixing

* Fix slash command quoting for *args

Co-authored-by: iDutchy <42503862+iDutchy@users.noreply.github.com>
Co-authored-by: Lint Action <lint-action@samuelmeuli.com>
2021-09-19 01:28:11 +02:00
Astrea
75a23351c4
Revert #42 (#61) 2021-09-09 00:02:02 +02:00
Lint Action
7513c2138f Fix code style issues with Black 2021-09-05 21:34:20 +00:00
IAmTomahawkx
a23dae8604 Merge branch '2.0' of https://github.com/IDevision/enhanced-discord.py into 2.0 2021-09-05 14:33:00 -07:00
IAmTomahawkx
1833e984ce add black workflow, change our code formats. closes #43 2021-09-05 14:32:51 -07:00
Gnome!
53a6b2cb45
Revert "Merge pull request #12" (#56)
This reverts commit 42c0a8d8a5840c00185e367933e61e2565bf7305.
2021-09-05 10:37:51 -07:00
iDutchy
65640ddfc7
Merge pull request #55 from TheMoksej/patch-2
remove unnecessary await
2021-09-05 15:24:45 +02:00
Moksej
14b3188bb8
remove unnecessary await 2021-09-05 13:58:10 +02:00
Arthur
3ffe134895
Merge pull request #44
* Typehint gateway.py

* Add relevant typehints to gateway.py to voice_client.py

* Change EventListener to subclass NamedTuple

* Add return type for DiscordWebSocket.wait_for

* Correct deque typehint

* Remove unnecessary typehints for literals

* Use type aliases

* Merge branch '2.0' into pr7422
2021-09-02 13:50:19 -07:00
Arthur
1032728311
Merge pull request #32
* Add get/fetch_member to ThreadMember objects
2021-09-02 13:43:19 -07:00
Arthur
33470ff196
Merge pull request #31
* Add bots and humans to TextChannel
2021-09-02 13:41:26 -07:00
Arthur
47e42d1648
Merge pull request #42
* implement WelcomeScreen

* copy over the kwargs issue.

* readable variable names

* modernise code

* modernise pt2

* Update discord/welcome_screen.py

* make pylance not cry from my onions

* type http.py

* remove extraneous import
2021-09-02 13:40:11 -07:00
Astrea
4055bafaa5
Merge pull request #47
* Added `on_raw_typing` event
2021-09-02 13:34:39 -07:00
IAmTomahawkx
152b61aabb fix recursionerror caused by a Pull Request 2021-09-02 12:49:38 -07:00
Ahmad Ansori Palembani
f37be7961a
Merge pull request #41
* Fixed `TypeError`

* Handles `EmptyEmbed` inside setter instead of set_

* Remove return and setter docstring
2021-09-02 12:46:56 -07:00
NightSlasher35
0f6db99c59
Merge pull request #22
* add nitro booster color

* Update discord/colour.py
2021-09-02 12:34:41 -07:00
chillymosh
42c0a8d8a5
Merge pull request #12
* Clean up python

* Clean up bot python

* revert lists

* revert commands.bot completely

* extract raise_expected_coro further

* add new lines

* removed erroneous import

* remove hashed line
2021-09-02 12:32:46 -07:00
Arthur
092fbca08f
Merge pull request #21
* [BREAKING] Make case_insensitive default to True on groups and commands
2021-09-02 12:28:03 -07:00
Arthur
13834d1147
Merge pull request #7
* Add try_user to get a user from cache or from the gateway.

* Extract populate_owners into a new coroutine.

* Add a try_owners coroutine to get a list of owners of the bot.

* Fix coding-style.

* Fix a bug where None would be returned in try_owners if the cache was…

* Fix docstring

* Add spacing in the code
2021-09-02 12:24:52 -07:00
Arthur
5d10384576
Merge pull request #27
* Add author_permissions to the Context object as a shortcut to return …
2021-09-02 12:18:26 -07:00
Daud
fc0188d7bc
Merge pull request #49
* Change README title to enhanced-discord.py
2021-09-02 12:17:19 -07:00
iDutchy
a62c0ff0d1
Merge pull request #16 from paris-ci/alias_administrator_to_admin
Alias admin to administrators in permissions. This needs to be tested…
2021-09-02 03:06:30 +02:00
iDutchy
b75be64044
Update permissions.py
A better implementation :)
2021-09-02 03:05:49 +02:00
NightSlasher35
630a842556
Update CONTRIBUTING.md correctly (#29)
* Update CONTRIBUTING.md

* Update CONTRIBUTING.md

Co-authored-by: Tom <47765953+IAmTomahawkx@users.noreply.github.com>
2021-09-01 17:49:41 -07:00
Arthur
c485e08ea0
Add try_member to guild. (#14)
* Add try_member to guild.

This also fix an omission in the fetch_member docs. fetch_member raises NotFound if the given user isn't in the guild.

* Optimize imports.
2021-09-01 17:47:15 -07:00
classerase
dba9a8abb9
Update README.rst (#48) 2021-09-01 15:30:15 -07:00
iDutchy
6f5614373a
Merge pull request #15 from WhoTheOOF/patch-3
Add original dark blurple
2021-09-01 04:40:12 +02:00
iDutchy
2e12746c70
Merge pull request #11 from TheMoksej/patch-2
versionadded needs to be added here
2021-09-01 04:38:26 +02:00
iDutchy
5ef72e4f70
Merge pull request #20 from paris-ci/special_methods
Special methods
2021-09-01 04:27:33 +02:00
iDutchy
7e18d30820
Merge pull request #17 from Astrea49/2.0
Prefer `static_format` over `format` with static assets
2021-09-01 04:26:31 +02:00
iDutchy
923a6a885d
Merge pull request #13 from paris-ci/rework_set_in_embeds
Make `Embed.image` and `Embed.thumbnail` full-featured properties
2021-09-01 04:24:57 +02:00
iDutchy
b28893aa36
Merge pull request #40 from Gnome-py/required-intents
Remove intents.default and make intents a required parameter
2021-09-01 04:21:37 +02:00
Gnome
6e41bd2219 Remove intents.default and make intents a required parameter 2021-08-31 20:53:54 +01:00
Moksej
773ad6f5bf
add back the silent kwarg to message.delete (#9)
* add back the silent kwarg to message.delete

* forgot about versionadded

* shorten the if statement

* simplify raising a bit ig?

* should be versionchanged instead

Co-authored-by: Arthur <site-github@api-d.com>

* remove `Optional` from parameter and doc string

Co-authored-by: Arthur <site-github@api-d.com>
2021-08-29 10:57:07 -07:00
Arthur
de0e8ef108
V2.0 changelog (#8)
* Copy in messages from Danny, verbatim

* Type whats_new

* Add my changes to the changelog

* Fix a typo
2021-08-29 10:55:49 -07:00
Arthur Jovart
64ee792391
Add int() support to Hashable, making it available across the board for AuditLogEntry, *Channel, Guild, Object, Message, ... 2021-08-29 01:21:20 +02:00
Arthur Jovart
22de755059
Add int() and str() support to Message 2021-08-29 01:09:05 +02:00
Arthur Jovart
fa7f8efc8e
Add int() support to Guild 2021-08-29 01:07:26 +02:00
Arthur Jovart
9d1df65af3
Add int() support to Role 2021-08-29 01:06:18 +02:00
Arthur Jovart
3ce86f6cde
Add int() support to Emoji 2021-08-29 01:05:28 +02:00
Arthur Jovart
31e3e99c2b
Add __int__ special method to User and Member 2021-08-29 00:59:29 +02:00
Jadon
cc90d312f5
Add original dark blurple
This adds the old discord dark blurple color as a classmethod for embeds and whatever.
2021-08-28 17:26:46 -05:00
Arthur Jovart
8cdc1f4ad9
Alias admin to administrators in permissions. This needs to be tested, but should be working. 2021-08-29 00:26:26 +02:00
Sonic4999
75f052b8c9 Prefer static_format over format with static assets 2021-08-28 18:24:05 -04:00
Arthur Jovart
c8cdb275c5
Fix set_* function name 2021-08-28 23:29:49 +02:00
Arthur Jovart
406f0ffe04
Make Embed.image and Embed.thumbnail full-featured properties
This avoids the need for set_* methods.
2021-08-28 23:14:26 +02:00
Moksej
a4acbd2e08
versionadded needs to be added here 2021-08-28 22:10:58 +02:00
124 changed files with 8052 additions and 5683 deletions

View File

@ -1,5 +1,7 @@
## Contributing to discord.py ## Contributing to discord.py
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1: First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules. The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a
Generally speaking questions are better suited in our resources below. Generally speaking questions are better suited in our resources below.
- The official support server: https://discord.gg/r3sSKJJ - The official support server: https://discord.gg/TvqYBrGXEm
- The Discord API server under #python_discord-py: https://discord.gg/discord-api - The Discord API server under #python_discord-py: https://discord.gg/discord-api
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html) - [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py) - [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience. Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.
@ -32,13 +34,13 @@ If the bug report is missing this information then it'll take us longer to fix t
## Submitting a Pull Request ## Submitting a Pull Request
Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows PEP-8 guidelines (mostly) with a column limit of 125. Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep, and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows the black code format, with a line length limit of `120`
### Git Commit Guidelines ### Git Commit Guidelines
- Use present tense (e.g. "Add feature" not "Added feature") - Use present tense (e.g. "Add feature" not "Added feature")
- Limit all lines to 72 characters or less. - Limit all lines to 120 characters or fewer.
- Reference issues or pull requests outside of the first line. - Reference issues or pull requests outside the first line.
- Please use the shorthand `#123` and not the full URL. - Please use the shorthand `#123` and not the full URL.
- Commits regarding the commands extension must be prefixed with `[commands]` - Commits regarding the commands extension must be prefixed with `[commands]`

View File

@ -6,7 +6,7 @@ body:
attributes: attributes:
value: > value: >
Thanks for taking the time to fill out a bug. Thanks for taking the time to fill out a bug.
If you want real-time support, consider joining our Discord at https://discord.gg/r3sSKJJ instead. If you want real-time support, consider joining our Discord at https://discord.gg/TvqYBrGXEm instead.
Please note that this form is for bugs only! Please note that this form is for bugs only!
- type: input - type: input

View File

@ -5,4 +5,4 @@ contact_links:
url: https://github.com/Rapptz/discord.py/discussions url: https://github.com/Rapptz/discord.py/discussions
- name: Discord Server - name: Discord Server
about: Use our official Discord server to ask for help and questions as well. about: Use our official Discord server to ask for help and questions as well.
url: https://discord.gg/r3sSKJJ url: https://discord.gg/TvqYBrGXEm

View File

@ -1,3 +1,5 @@
<!-- Pull requests that do not fill this information in will likely be closed -->
## Summary ## Summary
<!-- What is this pull request for? Does it fix any issues? --> <!-- What is this pull request for? Does it fix any issues? -->

38
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: CI
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Setup node.js (for pyright)
uses: actions/setup-node@v1
with:
node-version: "14"
- name: Run type checking
run: |
npm install -g pyright
pip install .
pyright --lib --verifytypes discord --ignoreexternal
black:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Run linter
uses: psf/black@stable
with:
options: "--line-length 120 --check"
src: "./discord"

1
.python-black Normal file
View File

@ -0,0 +1 @@

View File

@ -1,8 +1,8 @@
discord.py enhanced-discord.py
========== ===================
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png .. image:: https://discord.com/api/guilds/514232441498763279/embed.png
:target: https://discord.gg/PYAfZzpsjG :target: https://discord.gg/TvqYBrGXEm
:alt: Discord server invite :alt: Discord server invite
.. image:: https://img.shields.io/pypi/v/enhanced-dpy.svg .. image:: https://img.shields.io/pypi/v/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy :target: https://pypi.python.org/pypi/enhanced-dpy
@ -17,18 +17,6 @@ The Future of enhanced-discord.py
-------------------------- --------------------------
Enhanced discord.py is a fork of Rapptz's discord.py, that went unmaintained (`gist <https://gist.github.com/Rapptz/4a2f62751b9600a31a0d3c78100287f1>`_) Enhanced discord.py is a fork of Rapptz's discord.py, that went unmaintained (`gist <https://gist.github.com/Rapptz/4a2f62751b9600a31a0d3c78100287f1>`_)
It is currently maintained by (in alphabetical order)
- Chillymosh#8175
- Daggy#9889
- dank Had0cK#6081
- Dutchy#6127
- Eyesofcreeper#0001
- Gnome!#6669
- IAmTomahawkx#1000
- Jadon#2494
An overview of added features is available on the `custom features page <https://enhanced-dpy.readthedocs.io/en/latest/index.html#custom-features>`_. An overview of added features is available on the `custom features page <https://enhanced-dpy.readthedocs.io/en/latest/index.html#custom-features>`_.
Key Features Key Features
@ -59,7 +47,7 @@ To install the development version, do the following:
.. code:: sh .. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py $ git clone https://github.com/iDevision/enhanced-discord.py
$ cd discord.py $ cd enhanced-discord.py
$ python3 -m pip install -U .[voice] $ python3 -m pip install -U .[voice]
@ -117,5 +105,5 @@ Links
------ ------
- `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_ - `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_
- `Official Discord Server <https://discord.gg/PYAfZzpsjG>`_ - `Official Discord Server <https://discord.gg/TvqYBrGXEm>`_
- `Discord API <https://discord.gg/discord-api>`_ - `Discord API <https://discord.gg/discord-api>`_

View File

@ -9,13 +9,13 @@ A basic wrapper for the Discord API.
""" """
__title__ = 'discord' __title__ = "discord"
__author__ = 'Rapptz' __author__ = "Rapptz"
__license__ = 'MIT' __license__ = "MIT"
__copyright__ = 'Copyright 2015-present Rapptz' __copyright__ = "Copyright 2015-present Rapptz"
__version__ = '2.0.0a' __version__ = "2.0.0a"
__path__ = __import__('pkgutil').extend_path(__path__, __name__) __path__ = __import__("pkgutil").extend_path(__path__, __name__)
import logging import logging
from typing import NamedTuple, Literal from typing import NamedTuple, Literal
@ -69,6 +69,6 @@ class VersionInfo(NamedTuple):
serial: int serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0) version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="alpha", serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@ -31,26 +31,29 @@ import pkg_resources
import aiohttp import aiohttp
import platform import platform
def show_version(): def show_version():
entries = [] entries = []
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) entries.append("- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info))
version_info = discord.version_info version_info = discord.version_info
entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info)) entries.append("- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info))
if version_info.releaselevel != 'final': if version_info.releaselevel != "final":
pkg = pkg_resources.get_distribution('discord.py') pkg = pkg_resources.get_distribution("discord.py")
if pkg: if pkg:
entries.append(f' - discord.py pkg_resources: v{pkg.version}') entries.append(f" - discord.py pkg_resources: v{pkg.version}")
entries.append(f'- aiohttp v{aiohttp.__version__}') entries.append(f"- aiohttp v{aiohttp.__version__}")
uname = platform.uname() uname = platform.uname()
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) entries.append("- system info: {0.system} {0.release} {0.version}".format(uname))
print('\n'.join(entries)) print("\n".join(entries))
def core(parser, args): def core(parser, args):
if args.version: if args.version:
show_version() show_version()
_bot_template = """#!/usr/bin/env python3 _bot_template = """#!/usr/bin/env python3
from discord.ext import commands from discord.ext import commands
@ -120,7 +123,7 @@ def setup(bot):
bot.add_cog({name}(bot)) bot.add_cog({name}(bot))
''' '''
_cog_extras = ''' _cog_extras = """
def cog_unload(self): def cog_unload(self):
# clean up logic goes here # clean up logic goes here
pass pass
@ -149,22 +152,22 @@ _cog_extras = '''
# called after a command is called here # called after a command is called here
pass pass
''' """
# certain file names and directory names are forbidden # certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx # see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent # although some of this doesn't apply to Linux, we might as well be consistent
_base_table = { _base_table = {
'<': '-', "<": "-",
'>': '-', ">": "-",
':': '-', ":": "-",
'"': '-', '"': "-",
# '/': '-', these are fine # '/': '-', these are fine
# '\\': '-', # '\\': '-',
'|': '-', "|": "-",
'?': '-', "?": "-",
'*': '-', "*": "-",
} }
# NUL (0) and 1-31 are disallowed # NUL (0) and 1-31 are disallowed
@ -172,21 +175,45 @@ _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table) _translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False): def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path): if isinstance(name, Path):
return name return name
if sys.platform == 'win32': if sys.platform == "win32":
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ forbidden = (
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') "CON",
"PRN",
"AUX",
"NUL",
"COM1",
"COM2",
"COM3",
"COM4",
"COM5",
"COM6",
"COM7",
"COM8",
"COM9",
"LPT1",
"LPT2",
"LPT3",
"LPT4",
"LPT5",
"LPT6",
"LPT7",
"LPT8",
"LPT9",
)
if len(name) <= 4 and name.upper() in forbidden: if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one') parser.error("invalid directory name given, use a different one")
name = name.translate(_translation_table) name = name.translate(_translation_table)
if replace_spaces: if replace_spaces:
name = name.replace(' ', '-') name = name.replace(" ", "-")
return Path(name) return Path(name)
def newbot(parser, args): def newbot(parser, args):
new_directory = to_path(parser, args.directory) / to_path(parser, args.name) new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
@ -195,106 +222,114 @@ def newbot(parser, args):
try: try:
new_directory.mkdir(exist_ok=True, parents=True) new_directory.mkdir(exist_ok=True, parents=True)
except OSError as exc: except OSError as exc:
parser.error(f'could not create our bot directory ({exc})') parser.error(f"could not create our bot directory ({exc})")
cogs = new_directory / 'cogs' cogs = new_directory / "cogs"
try: try:
cogs.mkdir(exist_ok=True) cogs.mkdir(exist_ok=True)
init = cogs / '__init__.py' init = cogs / "__init__.py"
init.touch() init.touch()
except OSError as exc: except OSError as exc:
print(f'warning: could not create cogs directory ({exc})') print(f"warning: could not create cogs directory ({exc})")
try: try:
with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp: with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp:
fp.write('token = "place your token here"\ncogs = []\n') fp.write('token = "place your token here"\ncogs = []\n')
except OSError as exc: except OSError as exc:
parser.error(f'could not create config file ({exc})') parser.error(f"could not create config file ({exc})")
try: try:
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp: with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot' base = "Bot" if not args.sharded else "AutoShardedBot"
fp.write(_bot_template.format(base=base, prefix=args.prefix)) fp.write(_bot_template.format(base=base, prefix=args.prefix))
except OSError as exc: except OSError as exc:
parser.error(f'could not create bot file ({exc})') parser.error(f"could not create bot file ({exc})")
if not args.no_git: if not args.no_git:
try: try:
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp: with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp:
fp.write(_gitignore_template) fp.write(_gitignore_template)
except OSError as exc: except OSError as exc:
print(f'warning: could not create .gitignore file ({exc})') print(f"warning: could not create .gitignore file ({exc})")
print("successfully made bot at", new_directory)
print('successfully made bot at', new_directory)
def newcog(parser, args): def newcog(parser, args):
cog_dir = to_path(parser, args.directory) cog_dir = to_path(parser, args.directory)
try: try:
cog_dir.mkdir(exist_ok=True) cog_dir.mkdir(exist_ok=True)
except OSError as exc: except OSError as exc:
print(f'warning: could not create cogs directory ({exc})') print(f"warning: could not create cogs directory ({exc})")
directory = cog_dir / to_path(parser, args.name) directory = cog_dir / to_path(parser, args.name)
directory = directory.with_suffix('.py') directory = directory.with_suffix(".py")
try: try:
with open(str(directory), 'w', encoding='utf-8') as fp: with open(str(directory), "w", encoding="utf-8") as fp:
attrs = '' attrs = ""
extra = _cog_extras if args.full else '' extra = _cog_extras if args.full else ""
if args.class_name: if args.class_name:
name = args.class_name name = args.class_name
else: else:
name = str(directory.stem) name = str(directory.stem)
if '-' in name or '_' in name: if "-" in name or "_" in name:
translation = str.maketrans('-_', ' ') translation = str.maketrans("-_", " ")
name = name.translate(translation).title().replace(' ', '') name = name.translate(translation).title().replace(" ", "")
else: else:
name = name.title() name = name.title()
if args.display_name: if args.display_name:
attrs += f', name="{args.display_name}"' attrs += f', name="{args.display_name}"'
if args.hide_commands: if args.hide_commands:
attrs += ', command_attrs=dict(hidden=True)' attrs += ", command_attrs=dict(hidden=True)"
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc: except OSError as exc:
parser.error(f'could not create cog file ({exc})') parser.error(f"could not create cog file ({exc})")
else: else:
print('successfully made cog at', directory) print("successfully made cog at", directory)
def add_newbot_args(subparser): def add_newbot_args(subparser):
parser = subparser.add_parser('newbot', help='creates a command bot project quickly') parser = subparser.add_parser("newbot", help="creates a command bot project quickly")
parser.set_defaults(func=newbot) parser.set_defaults(func=newbot)
parser.add_argument('name', help='the bot project name') parser.add_argument("name", help="the bot project name")
parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd()) parser.add_argument("directory", help="the directory to place it in (default: .)", nargs="?", default=Path.cwd())
parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='<prefix>') parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>")
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true")
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') parser.add_argument("--no-git", help="do not create a .gitignore file", action="store_true", dest="no_git")
def add_newcog_args(subparser): def add_newcog_args(subparser):
parser = subparser.add_parser('newcog', help='creates a new cog template quickly') parser = subparser.add_parser("newcog", help="creates a new cog template quickly")
parser.set_defaults(func=newcog) parser.set_defaults(func=newcog)
parser.add_argument('name', help='the cog name') parser.add_argument("name", help="the cog name")
parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs')) parser.add_argument(
parser.add_argument('--class-name', help='the class name of the cog (default: <name>)', dest='class_name') "directory", help="the directory to place it in (default: cogs)", nargs="?", default=Path("cogs")
parser.add_argument('--display-name', help='the cog name (default: <name>)') )
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') parser.add_argument("--class-name", help="the class name of the cog (default: <name>)", dest="class_name")
parser.add_argument('--full', help='add all special methods as well', action='store_true') parser.add_argument("--display-name", help="the cog name (default: <name>)")
parser.add_argument("--hide-commands", help="whether to hide all commands in the cog", action="store_true")
parser.add_argument("--full", help="add all special methods as well", action="store_true")
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with discord.py")
parser.add_argument('-v', '--version', action='store_true', help='shows the library version') parser.add_argument("-v", "--version", action="store_true", help="shows the library version")
parser.set_defaults(func=core) parser.set_defaults(func=core)
subparser = parser.add_subparsers(dest='subcommand', title='subcommands') subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
add_newbot_args(subparser) add_newbot_args(subparser)
add_newcog_args(subparser) add_newcog_args(subparser)
return parser, parser.parse_args() return parser, parser.parse_args()
def main(): def main():
parser, args = parse_args() parser, args = parse_args()
args.func(parser, args) args.func(parser, args)
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View File

@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem
from . import utils from . import utils
__all__ = ( __all__ = (
'Snowflake', "Snowflake",
'User', "User",
'PrivateChannel', "PrivateChannel",
'GuildChannel', "GuildChannel",
'Messageable', "Messageable",
'Connectable', "Connectable",
) )
T = TypeVar('T', bound=VoiceProtocol) T = TypeVar("T", bound=VoiceProtocol)
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import datetime from datetime import datetime
@ -98,7 +98,7 @@ MISSING = utils.MISSING
class _Undefined: class _Undefined:
def __repr__(self) -> str: def __repr__(self) -> str:
return 'see-below' return "see-below"
_undefined: Any = _Undefined() _undefined: Any = _Undefined()
@ -189,23 +189,23 @@ class PrivateChannel(Snowflake, Protocol):
class _Overwrites: class _Overwrites:
__slots__ = ('id', 'allow', 'deny', 'type') __slots__ = ("id", "allow", "deny", "type")
ROLE = 0 ROLE = 0
MEMBER = 1 MEMBER = 1
def __init__(self, data: PermissionOverwritePayload): def __init__(self, data: PermissionOverwritePayload):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.allow: int = int(data.get('allow', 0)) self.allow: int = int(data.get("allow", 0))
self.deny: int = int(data.get('deny', 0)) self.deny: int = int(data.get("deny", 0))
self.type: OverwriteType = data['type'] self.type: OverwriteType = data["type"]
def _asdict(self) -> PermissionOverwritePayload: def _asdict(self) -> PermissionOverwritePayload:
return { return {
'id': self.id, "id": self.id,
'allow': str(self.allow), "allow": str(self.allow),
'deny': str(self.deny), "deny": str(self.deny),
'type': self.type, "type": self.type,
} }
def is_role(self) -> bool: def is_role(self) -> bool:
@ -215,7 +215,7 @@ class _Overwrites:
return self.type == 1 return self.type == 1
GCH = TypeVar('GCH', bound='GuildChannel') GCH = TypeVar("GCH", bound="GuildChannel")
class GuildChannel: class GuildChannel:
@ -276,7 +276,7 @@ class GuildChannel:
reason: Optional[str], reason: Optional[str],
) -> None: ) -> None:
if position < 0: if position < 0:
raise InvalidArgument('Channel position cannot be less than 0.') raise InvalidArgument("Channel position cannot be less than 0.")
http = self._state.http http = self._state.http
bucket = self._sorting_bucket bucket = self._sorting_bucket
@ -297,7 +297,7 @@ class GuildChannel:
payload = [] payload = []
for index, c in enumerate(channels): for index, c in enumerate(channels):
d: Dict[str, Any] = {'id': c.id, 'position': index} d: Dict[str, Any] = {"id": c.id, "position": index}
if parent_id is not _undefined and c.id == self.id: if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions) d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d) payload.append(d)
@ -306,81 +306,81 @@ class GuildChannel:
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]: async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
try: try:
parent = options.pop('category') parent = options.pop("category")
except KeyError: except KeyError:
parent_id = _undefined parent_id = _undefined
else: else:
parent_id = parent and parent.id parent_id = parent and parent.id
try: try:
options['rate_limit_per_user'] = options.pop('slowmode_delay') options["rate_limit_per_user"] = options.pop("slowmode_delay")
except KeyError: except KeyError:
pass pass
try: try:
rtc_region = options.pop('rtc_region') rtc_region = options.pop("rtc_region")
except KeyError: except KeyError:
pass pass
else: else:
options['rtc_region'] = None if rtc_region is None else str(rtc_region) options["rtc_region"] = None if rtc_region is None else str(rtc_region)
try: try:
video_quality_mode = options.pop('video_quality_mode') video_quality_mode = options.pop("video_quality_mode")
except KeyError: except KeyError:
pass pass
else: else:
options['video_quality_mode'] = int(video_quality_mode) options["video_quality_mode"] = int(video_quality_mode)
lock_permissions = options.pop('sync_permissions', False) lock_permissions = options.pop("sync_permissions", False)
try: try:
position = options.pop('position') position = options.pop("position")
except KeyError: except KeyError:
if parent_id is not _undefined: if parent_id is not _undefined:
if lock_permissions: if lock_permissions:
category = self.guild.get_channel(parent_id) category = self.guild.get_channel(parent_id)
if category: if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites] options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id options["parent_id"] = parent_id
elif lock_permissions and self.category_id is not None: elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it # if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category # we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id) category = self.guild.get_channel(self.category_id)
if category: if category:
options['permission_overwrites'] = [c._asdict() for c in category._overwrites] options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
else: else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
overwrites = options.get('overwrites', None) overwrites = options.get("overwrites", None)
if overwrites is not None: if overwrites is not None:
perms = [] perms = []
for target, perm in overwrites.items(): for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite): if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")
allow, deny = perm.pair() allow, deny = perm.pair()
payload = { payload = {
'allow': allow.value, "allow": allow.value,
'deny': deny.value, "deny": deny.value,
'id': target.id, "id": target.id,
} }
if isinstance(target, Role): if isinstance(target, Role):
payload['type'] = _Overwrites.ROLE payload["type"] = _Overwrites.ROLE
else: else:
payload['type'] = _Overwrites.MEMBER payload["type"] = _Overwrites.MEMBER
perms.append(payload) perms.append(payload)
options['permission_overwrites'] = perms options["permission_overwrites"] = perms
try: try:
ch_type = options['type'] ch_type = options["type"]
except KeyError: except KeyError:
pass pass
else: else:
if not isinstance(ch_type, ChannelType): if not isinstance(ch_type, ChannelType):
raise InvalidArgument('type field must be of type ChannelType') raise InvalidArgument("type field must be of type ChannelType")
options['type'] = ch_type.value options["type"] = ch_type.value
if options: if options:
return await self._state.http.edit_channel(self.id, reason=reason, **options) return await self._state.http.edit_channel(self.id, reason=reason, **options)
@ -390,7 +390,7 @@ class GuildChannel:
everyone_index = 0 everyone_index = 0
everyone_id = self.guild.id everyone_id = self.guild.id
for index, overridden in enumerate(data.get('permission_overwrites', [])): for index, overridden in enumerate(data.get("permission_overwrites", [])):
overwrite = _Overwrites(overridden) overwrite = _Overwrites(overridden)
self._overwrites.append(overwrite) self._overwrites.append(overwrite)
@ -429,7 +429,7 @@ class GuildChannel:
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel.""" """:class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>' return f"<#{self.id}>"
@property @property
def created_at(self) -> datetime: def created_at(self) -> datetime:
@ -779,18 +779,18 @@ class GuildChannel:
elif isinstance(target, Role): elif isinstance(target, Role):
perm_type = _Overwrites.ROLE perm_type = _Overwrites.ROLE
else: else:
raise InvalidArgument('target parameter must be either Member or Role') raise InvalidArgument("target parameter must be either Member or Role")
if overwrite is _undefined: if overwrite is _undefined:
if len(permissions) == 0: if len(permissions) == 0:
raise InvalidArgument('No overwrite provided.') raise InvalidArgument("No overwrite provided.")
try: try:
overwrite = PermissionOverwrite(**permissions) overwrite = PermissionOverwrite(**permissions)
except (ValueError, TypeError): except (ValueError, TypeError):
raise InvalidArgument('Invalid permissions given to keyword arguments.') raise InvalidArgument("Invalid permissions given to keyword arguments.")
else: else:
if len(permissions) > 0: if len(permissions) > 0:
raise InvalidArgument('Cannot mix overwrite and keyword arguments.') raise InvalidArgument("Cannot mix overwrite and keyword arguments.")
# TODO: wait for event # TODO: wait for event
@ -800,7 +800,7 @@ class GuildChannel:
(allow, deny) = overwrite.pair() (allow, deny) = overwrite.pair()
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
else: else:
raise InvalidArgument('Invalid overwrite type provided.') raise InvalidArgument("Invalid overwrite type provided.")
async def _clone_impl( async def _clone_impl(
self: GCH, self: GCH,
@ -809,9 +809,9 @@ class GuildChannel:
name: Optional[str] = None, name: Optional[str] = None,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> GCH: ) -> GCH:
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id base_attrs["parent_id"] = self.category_id
base_attrs['name'] = name or self.name base_attrs["name"] = name or self.name
guild_id = self.guild.id guild_id = self.guild.id
cls = self.__class__ cls = self.__class__
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
@ -964,14 +964,14 @@ class GuildChannel:
if not kwargs: if not kwargs:
return return
beginning, end = kwargs.get('beginning'), kwargs.get('end') beginning, end = kwargs.get("beginning"), kwargs.get("end")
before, after = kwargs.get('before'), kwargs.get('after') before, after = kwargs.get("before"), kwargs.get("after")
offset = kwargs.get('offset', 0) offset = kwargs.get("offset", 0)
if sum(bool(a) for a in (beginning, end, before, after)) > 1: if sum(bool(a) for a in (beginning, end, before, after)) > 1:
raise InvalidArgument('Only one of [before, after, end, beginning] can be used.') raise InvalidArgument("Only one of [before, after, end, beginning] can be used.")
bucket = self._sorting_bucket bucket = self._sorting_bucket
parent_id = kwargs.get('category', MISSING) parent_id = kwargs.get("category", MISSING)
# fmt: off # fmt: off
channels: List[GuildChannel] channels: List[GuildChannel]
if parent_id not in (MISSING, None): if parent_id not in (MISSING, None):
@ -1011,14 +1011,14 @@ class GuildChannel:
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)
if index is None: if index is None:
raise InvalidArgument('Could not resolve appropriate move position') raise InvalidArgument("Could not resolve appropriate move position")
channels.insert(max((index + offset), 0), self) channels.insert(max((index + offset), 0), self)
payload = [] payload = []
lock_permissions = kwargs.get('sync_permissions', False) lock_permissions = kwargs.get("sync_permissions", False)
reason = kwargs.get('reason') reason = kwargs.get("reason")
for index, channel in enumerate(channels): for index, channel in enumerate(channels):
d = {'id': channel.id, 'position': index} d = {"id": channel.id, "position": index}
if parent_id is not MISSING and channel.id == self.id: if parent_id is not MISSING and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions) d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d) payload.append(d)
@ -1332,14 +1332,14 @@ class Messageable:
content = str(content) if content is not None else None content = str(content) if content is not None else None
if embed is not None and embeds is not None: if embed is not None and embeds is not None:
raise InvalidArgument('cannot pass both embed and embeds parameter to send()') raise InvalidArgument("cannot pass both embed and embeds parameter to send()")
if embed is not None: if embed is not None:
embed = embed.to_dict() embed = embed.to_dict()
elif embeds is not None: elif embeds is not None:
if len(embeds) > 10: if len(embeds) > 10:
raise InvalidArgument('embeds parameter must be a list of up to 10 elements') raise InvalidArgument("embeds parameter must be a list of up to 10 elements")
embeds = [embed.to_dict() for embed in embeds] embeds = [embed.to_dict() for embed in embeds]
if stickers is not None: if stickers is not None:
@ -1355,28 +1355,30 @@ class Messageable:
if mention_author is not None: if mention_author is not None:
allowed_mentions = allowed_mentions or AllowedMentions().to_dict() allowed_mentions = allowed_mentions or AllowedMentions().to_dict()
allowed_mentions['replied_user'] = bool(mention_author) allowed_mentions["replied_user"] = bool(mention_author)
if reference is not None: if reference is not None:
try: try:
reference = reference.to_message_reference_dict() reference = reference.to_message_reference_dict()
except AttributeError: except AttributeError:
raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None raise InvalidArgument(
"reference parameter must be Message, MessageReference, or PartialMessage"
) from None
if view: if view:
if not hasattr(view, '__discord_ui_view__'): if not hasattr(view, "__discord_ui_view__"):
raise InvalidArgument(f'view parameter must be View not {view.__class__!r}') raise InvalidArgument(f"view parameter must be View not {view.__class__!r}")
components = view.to_components() components = view.to_components()
else: else:
components = None components = None
if file is not None and files is not None: if file is not None and files is not None:
raise InvalidArgument('cannot pass both file and files parameter to send()') raise InvalidArgument("cannot pass both file and files parameter to send()")
if file is not None: if file is not None:
if not isinstance(file, File): if not isinstance(file, File):
raise InvalidArgument('file parameter must be File') raise InvalidArgument("file parameter must be File")
try: try:
data = await state.http.send_files( data = await state.http.send_files(
@ -1397,9 +1399,9 @@ class Messageable:
elif files is not None: elif files is not None:
if len(files) > 10: if len(files) > 10:
raise InvalidArgument('files parameter must be a list of up to 10 elements') raise InvalidArgument("files parameter must be a list of up to 10 elements")
elif not all(isinstance(file, File) for file in files): elif not all(isinstance(file, File) for file in files):
raise InvalidArgument('files parameter must be a list of File') raise InvalidArgument("files parameter must be a list of File")
try: try:
data = await state.http.send_files( data = await state.http.send_files(
@ -1666,13 +1668,13 @@ class Connectable(Protocol):
state = self._state state = self._state
if state._get_voice_client(key_id): if state._get_voice_client(key_id):
raise ClientException('Already connected to a voice channel.') raise ClientException("Already connected to a voice channel.")
client = state._get_client() client = state._get_client()
voice = cls(client, self) voice = cls(client, self)
if not isinstance(voice, VoiceProtocol): if not isinstance(voice, VoiceProtocol):
raise TypeError('Type must meet VoiceProtocol abstract base class.') raise TypeError("Type must meet VoiceProtocol abstract base class.")
state._add_voice_client(key_id, voice) state._add_voice_client(key_id, voice)

View File

@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji
from .utils import _get_as_snowflake from .utils import _get_as_snowflake
__all__ = ( __all__ = (
'BaseActivity', "BaseActivity",
'Activity', "Activity",
'Streaming', "Streaming",
'Game', "Game",
'Spotify', "Spotify",
'CustomActivity', "CustomActivity",
) )
"""If curious, this is the current schema for an activity. """If curious, this is the current schema for an activity.
@ -119,10 +119,10 @@ class BaseActivity:
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('_created_at',) __slots__ = ("_created_at",)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._created_at: Optional[float] = kwargs.pop('created_at', None) self._created_at: Optional[float] = kwargs.pop("created_at", None)
@property @property
def created_at(self) -> Optional[datetime.datetime]: def created_at(self) -> Optional[datetime.datetime]:
@ -199,58 +199,58 @@ class Activity(BaseActivity):
""" """
__slots__ = ( __slots__ = (
'state', "state",
'details', "details",
'_created_at', "_created_at",
'timestamps', "timestamps",
'assets', "assets",
'party', "party",
'flags', "flags",
'sync_id', "sync_id",
'session_id', "session_id",
'type', "type",
'name', "name",
'url', "url",
'application_id', "application_id",
'emoji', "emoji",
'buttons', "buttons",
) )
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None) self.state: Optional[str] = kwargs.pop("state", None)
self.details: Optional[str] = kwargs.pop('details', None) self.details: Optional[str] = kwargs.pop("details", None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {}) self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {})
self.assets: ActivityAssets = kwargs.pop('assets', {}) self.assets: ActivityAssets = kwargs.pop("assets", {})
self.party: ActivityParty = kwargs.pop('party', {}) self.party: ActivityParty = kwargs.pop("party", {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id') self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id")
self.name: Optional[str] = kwargs.pop('name', None) self.name: Optional[str] = kwargs.pop("name", None)
self.url: Optional[str] = kwargs.pop('url', None) self.url: Optional[str] = kwargs.pop("url", None)
self.flags: int = kwargs.pop('flags', 0) self.flags: int = kwargs.pop("flags", 0)
self.sync_id: Optional[str] = kwargs.pop('sync_id', None) self.sync_id: Optional[str] = kwargs.pop("sync_id", None)
self.session_id: Optional[str] = kwargs.pop('session_id', None) self.session_id: Optional[str] = kwargs.pop("session_id", None)
self.buttons: List[ActivityButton] = kwargs.pop('buttons', []) self.buttons: List[ActivityButton] = kwargs.pop("buttons", [])
activity_type = kwargs.pop('type', -1) activity_type = kwargs.pop("type", -1)
self.type: ActivityType = ( self.type: ActivityType = (
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
) )
emoji = kwargs.pop('emoji', None) emoji = kwargs.pop("emoji", None)
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ( attrs = (
('type', self.type), ("type", self.type),
('name', self.name), ("name", self.name),
('url', self.url), ("url", self.url),
('details', self.details), ("details", self.details),
('application_id', self.application_id), ("application_id", self.application_id),
('session_id', self.session_id), ("session_id", self.session_id),
('emoji', self.emoji), ("emoji", self.emoji),
) )
inner = ' '.join('%s=%r' % t for t in attrs) inner = " ".join("%s=%r" % t for t in attrs)
return f'<Activity {inner}>' return f"<Activity {inner}>"
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
ret: Dict[str, Any] = {} ret: Dict[str, Any] = {}
@ -263,16 +263,16 @@ class Activity(BaseActivity):
continue continue
ret[attr] = value ret[attr] = value
ret['type'] = int(self.type) ret["type"] = int(self.type)
if self.emoji: if self.emoji:
ret['emoji'] = self.emoji.to_dict() ret["emoji"] = self.emoji.to_dict()
return ret return ret
@property @property
def start(self) -> Optional[datetime.datetime]: def start(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try: try:
timestamp = self.timestamps['start'] / 1000 timestamp = self.timestamps["start"] / 1000
except KeyError: except KeyError:
return None return None
else: else:
@ -282,7 +282,7 @@ class Activity(BaseActivity):
def end(self) -> Optional[datetime.datetime]: def end(self) -> Optional[datetime.datetime]:
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try: try:
timestamp = self.timestamps['end'] / 1000 timestamp = self.timestamps["end"] / 1000
except KeyError: except KeyError:
return None return None
else: else:
@ -295,11 +295,11 @@ class Activity(BaseActivity):
return None return None
try: try:
large_image = self.assets['large_image'] large_image = self.assets["large_image"]
except KeyError: except KeyError:
return None return None
else: else:
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png' return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png"
@property @property
def small_image_url(self) -> Optional[str]: def small_image_url(self) -> Optional[str]:
@ -308,21 +308,21 @@ class Activity(BaseActivity):
return None return None
try: try:
small_image = self.assets['small_image'] small_image = self.assets["small_image"]
except KeyError: except KeyError:
return None return None
else: else:
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png' return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png"
@property @property
def large_image_text(self) -> Optional[str]: def large_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
return self.assets.get('large_text', None) return self.assets.get("large_text", None)
@property @property
def small_image_text(self) -> Optional[str]: def small_image_text(self) -> Optional[str]:
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
return self.assets.get('small_text', None) return self.assets.get("small_text", None)
class Game(BaseActivity): class Game(BaseActivity):
@ -359,20 +359,20 @@ class Game(BaseActivity):
The game's name. The game's name.
""" """
__slots__ = ('name', '_end', '_start') __slots__ = ("name", "_end", "_start")
def __init__(self, name: str, **extra): def __init__(self, name: str, **extra):
super().__init__(**extra) super().__init__(**extra)
self.name: str = name self.name: str = name
try: try:
timestamps: ActivityTimestamps = extra['timestamps'] timestamps: ActivityTimestamps = extra["timestamps"]
except KeyError: except KeyError:
self._start = 0 self._start = 0
self._end = 0 self._end = 0
else: else:
self._start = timestamps.get('start', 0) self._start = timestamps.get("start", 0)
self._end = timestamps.get('end', 0) self._end = timestamps.get("end", 0)
@property @property
def type(self) -> ActivityType: def type(self) -> ActivityType:
@ -400,15 +400,15 @@ class Game(BaseActivity):
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Game name={self.name!r}>' return f"<Game name={self.name!r}>"
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
timestamps: Dict[str, Any] = {} timestamps: Dict[str, Any] = {}
if self._start: if self._start:
timestamps['start'] = self._start timestamps["start"] = self._start
if self._end: if self._end:
timestamps['end'] = self._end timestamps["end"] = self._end
# fmt: off # fmt: off
return { return {
@ -473,16 +473,16 @@ class Streaming(BaseActivity):
A dictionary comprising of similar keys than those in :attr:`Activity.assets`. A dictionary comprising of similar keys than those in :attr:`Activity.assets`.
""" """
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') __slots__ = ("platform", "name", "game", "url", "details", "assets")
def __init__(self, *, name: Optional[str], url: str, **extra: Any): def __init__(self, *, name: Optional[str], url: str, **extra: Any):
super().__init__(**extra) super().__init__(**extra)
self.platform: Optional[str] = name self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop('details', name) self.name: Optional[str] = extra.pop("details", name)
self.game: Optional[str] = extra.pop('state', None) self.game: Optional[str] = extra.pop("state", None)
self.url: str = url self.url: str = url
self.details: Optional[str] = extra.pop('details', self.name) # compatibility self.details: Optional[str] = extra.pop("details", self.name) # compatibility
self.assets: ActivityAssets = extra.pop('assets', {}) self.assets: ActivityAssets = extra.pop("assets", {})
@property @property
def type(self) -> ActivityType: def type(self) -> ActivityType:
@ -496,7 +496,7 @@ class Streaming(BaseActivity):
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Streaming name={self.name!r}>' return f"<Streaming name={self.name!r}>"
@property @property
def twitch_name(self): def twitch_name(self):
@ -507,11 +507,11 @@ class Streaming(BaseActivity):
""" """
try: try:
name = self.assets['large_image'] name = self.assets["large_image"]
except KeyError: except KeyError:
return None return None
else: else:
return name[7:] if name[:7] == 'twitch:' else None return name[7:] if name[:7] == "twitch:" else None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
# fmt: off # fmt: off
@ -523,7 +523,7 @@ class Streaming(BaseActivity):
} }
# fmt: on # fmt: on
if self.details: if self.details:
ret['details'] = self.details ret["details"] = self.details
return ret return ret
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
@ -559,17 +559,17 @@ class Spotify:
Returns the string 'Spotify'. Returns the string 'Spotify'.
""" """
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') __slots__ = ("_state", "_details", "_timestamps", "_assets", "_party", "_sync_id", "_session_id", "_created_at")
def __init__(self, **data): def __init__(self, **data):
self._state: str = data.pop('state', '') self._state: str = data.pop("state", "")
self._details: str = data.pop('details', '') self._details: str = data.pop("details", "")
self._timestamps: Dict[str, int] = data.pop('timestamps', {}) self._timestamps: Dict[str, int] = data.pop("timestamps", {})
self._assets: ActivityAssets = data.pop('assets', {}) self._assets: ActivityAssets = data.pop("assets", {})
self._party: ActivityParty = data.pop('party', {}) self._party: ActivityParty = data.pop("party", {})
self._sync_id: str = data.pop('sync_id') self._sync_id: str = data.pop("sync_id")
self._session_id: str = data.pop('session_id') self._session_id: str = data.pop("session_id")
self._created_at: Optional[float] = data.pop('created_at', None) self._created_at: Optional[float] = data.pop("created_at", None)
@property @property
def type(self) -> ActivityType: def type(self) -> ActivityType:
@ -604,21 +604,21 @@ class Spotify:
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { return {
'flags': 48, # SYNC | PLAY "flags": 48, # SYNC | PLAY
'name': 'Spotify', "name": "Spotify",
'assets': self._assets, "assets": self._assets,
'party': self._party, "party": self._party,
'sync_id': self._sync_id, "sync_id": self._sync_id,
'session_id': self._session_id, "session_id": self._session_id,
'timestamps': self._timestamps, "timestamps": self._timestamps,
'details': self._details, "details": self._details,
'state': self._state, "state": self._state,
} }
@property @property
def name(self) -> str: def name(self) -> str:
""":class:`str`: The activity's name. This will always return "Spotify".""" """:class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify' return "Spotify"
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return ( return (
@ -635,10 +635,10 @@ class Spotify:
return hash(self._session_id) return hash(self._session_id)
def __str__(self) -> str: def __str__(self) -> str:
return 'Spotify' return "Spotify"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>' return f"<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>"
@property @property
def title(self) -> str: def title(self) -> str:
@ -648,7 +648,7 @@ class Spotify:
@property @property
def artists(self) -> List[str]: def artists(self) -> List[str]:
"""List[:class:`str`]: The artists of the song being played.""" """List[:class:`str`]: The artists of the song being played."""
return self._state.split('; ') return self._state.split("; ")
@property @property
def artist(self) -> str: def artist(self) -> str:
@ -662,16 +662,16 @@ class Spotify:
@property @property
def album(self) -> str: def album(self) -> str:
""":class:`str`: The album that the song being played belongs to.""" """:class:`str`: The album that the song being played belongs to."""
return self._assets.get('large_text', '') return self._assets.get("large_text", "")
@property @property
def album_cover_url(self) -> str: def album_cover_url(self) -> str:
""":class:`str`: The album cover image URL from Spotify's CDN.""" """:class:`str`: The album cover image URL from Spotify's CDN."""
large_image = self._assets.get('large_image', '') large_image = self._assets.get("large_image", "")
if large_image[:8] != 'spotify:': if large_image[:8] != "spotify:":
return '' return ""
album_image_id = large_image[8:] album_image_id = large_image[8:]
return 'https://i.scdn.co/image/' + album_image_id return "https://i.scdn.co/image/" + album_image_id
@property @property
def track_id(self) -> str: def track_id(self) -> str:
@ -684,17 +684,17 @@ class Spotify:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return f'https://open.spotify.com/track/{self.track_id}' return f"https://open.spotify.com/track/{self.track_id}"
@property @property
def start(self) -> datetime.datetime: def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC.""" """:class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc)
@property @property
def end(self) -> datetime.datetime: def end(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user will stop playing this song in UTC.""" """:class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc)
@property @property
def duration(self) -> datetime.timedelta: def duration(self) -> datetime.timedelta:
@ -704,7 +704,7 @@ class Spotify:
@property @property
def party_id(self) -> str: def party_id(self) -> str:
""":class:`str`: The party ID of the listening party.""" """:class:`str`: The party ID of the listening party."""
return self._party.get('id', '') return self._party.get("id", "")
class CustomActivity(BaseActivity): class CustomActivity(BaseActivity):
@ -738,13 +738,13 @@ class CustomActivity(BaseActivity):
The emoji to pass to the activity, if any. The emoji to pass to the activity, if any.
""" """
__slots__ = ('name', 'emoji', 'state') __slots__ = ("name", "emoji", "state")
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
super().__init__(**extra) super().__init__(**extra)
self.name: Optional[str] = name self.name: Optional[str] = name
self.state: Optional[str] = extra.pop('state', None) self.state: Optional[str] = extra.pop("state", None)
if self.name == 'Custom Status': if self.name == "Custom Status":
self.name = self.state self.name = self.state
self.emoji: Optional[PartialEmoji] self.emoji: Optional[PartialEmoji]
@ -757,7 +757,7 @@ class CustomActivity(BaseActivity):
elif isinstance(emoji, PartialEmoji): elif isinstance(emoji, PartialEmoji):
self.emoji = emoji self.emoji = emoji
else: else:
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.') raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.")
@property @property
def type(self) -> ActivityType: def type(self) -> ActivityType:
@ -770,18 +770,18 @@ class CustomActivity(BaseActivity):
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
if self.name == self.state: if self.name == self.state:
o = { o = {
'type': ActivityType.custom.value, "type": ActivityType.custom.value,
'state': self.name, "state": self.name,
'name': 'Custom Status', "name": "Custom Status",
} }
else: else:
o = { o = {
'type': ActivityType.custom.value, "type": ActivityType.custom.value,
'name': self.name, "name": self.name,
} }
if self.emoji: if self.emoji:
o['emoji'] = self.emoji.to_dict() o["emoji"] = self.emoji.to_dict()
return o return o
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
@ -796,47 +796,50 @@ class CustomActivity(BaseActivity):
def __str__(self) -> str: def __str__(self) -> str:
if self.emoji: if self.emoji:
if self.name: if self.name:
return f'{self.emoji} {self.name}' return f"{self.emoji} {self.name}"
return str(self.emoji) return str(self.emoji)
else: else:
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>' return f"<CustomActivity name={self.name!r} emoji={self.emoji!r}>"
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload @overload
def create_activity(data: ActivityPayload) -> ActivityTypes: def create_activity(data: ActivityPayload) -> ActivityTypes:
... ...
@overload @overload
def create_activity(data: None) -> None: def create_activity(data: None) -> None:
... ...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data: if not data:
return None return None
game_type = try_enum(ActivityType, data.get('type', -1)) game_type = try_enum(ActivityType, data.get("type", -1))
if game_type is ActivityType.playing: if game_type is ActivityType.playing:
if 'application_id' in data or 'session_id' in data: if "application_id" in data or "session_id" in data:
return Activity(**data) return Activity(**data)
return Game(**data) return Game(**data)
elif game_type is ActivityType.custom: elif game_type is ActivityType.custom:
try: try:
name = data.pop('name') name = data.pop("name")
except KeyError: except KeyError:
return Activity(**data) return Activity(**data)
else: else:
# we removed the name key from data already # we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming: elif game_type is ActivityType.streaming:
if 'url' in data: if "url" in data:
# the url won't be None here # the url won't be None here
return Streaming(**data) # type: ignore return Streaming(**data) # type: ignore
return Activity(**data) return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data:
return Spotify(**data) return Spotify(**data)
return Activity(**data) return Activity(**data)

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
__all__ = ( __all__ = (
'AppInfo', "AppInfo",
'PartialAppInfo', "PartialAppInfo",
) )
@ -115,58 +115,58 @@ class AppInfo:
""" """
__slots__ = ( __slots__ = (
'_state', "_state",
'description', "description",
'id', "id",
'name', "name",
'rpc_origins', "rpc_origins",
'bot_public', "bot_public",
'bot_require_code_grant', "bot_require_code_grant",
'owner', "owner",
'_icon', "_icon",
'summary', "summary",
'verify_key', "verify_key",
'team', "team",
'guild_id', "guild_id",
'primary_sku_id', "primary_sku_id",
'slug', "slug",
'_cover_image', "_cover_image",
'terms_of_service_url', "terms_of_service_url",
'privacy_policy_url', "privacy_policy_url",
) )
def __init__(self, state: ConnectionState, data: AppInfoPayload): def __init__(self, state: ConnectionState, data: AppInfoPayload):
from .team import Team from .team import Team
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self.description: str = data['description'] self.description: str = data["description"]
self._icon: Optional[str] = data['icon'] self._icon: Optional[str] = data["icon"]
self.rpc_origins: List[str] = data['rpc_origins'] self.rpc_origins: List[str] = data["rpc_origins"]
self.bot_public: bool = data['bot_public'] self.bot_public: bool = data["bot_public"]
self.bot_require_code_grant: bool = data['bot_require_code_grant'] self.bot_require_code_grant: bool = data["bot_require_code_grant"]
self.owner: User = state.create_user(data['owner']) self.owner: User = state.create_user(data["owner"])
team: Optional[TeamPayload] = data.get('team') team: Optional[TeamPayload] = data.get("team")
self.team: Optional[Team] = Team(state, team) if team else None self.team: Optional[Team] = Team(state, team) if team else None
self.summary: str = data['summary'] self.summary: str = data["summary"]
self.verify_key: str = data['verify_key'] self.verify_key: str = data["verify_key"]
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, "primary_sku_id")
self.slug: Optional[str] = data.get('slug') self.slug: Optional[str] = data.get("slug")
self._cover_image: Optional[str] = data.get('cover_image') self._cover_image: Optional[str] = data.get("cover_image")
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f'description={self.description!r} public={self.bot_public} ' f"description={self.description!r} public={self.bot_public} "
f'owner={self.owner!r}>' f"owner={self.owner!r}>"
) )
@property @property
@ -174,7 +174,7 @@ class AppInfo:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None: if self._icon is None:
return None return None
return Asset._from_icon(self._state, self.id, self._icon, path='app') return Asset._from_icon(self._state, self.id, self._icon, path="app")
@property @property
def cover_image(self) -> Optional[Asset]: def cover_image(self) -> Optional[Asset]:
@ -195,6 +195,7 @@ class AppInfo:
""" """
return self._state._get_guild(self.guild_id) return self._state._get_guild(self.guild_id)
class PartialAppInfo: class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite` """Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
@ -222,26 +223,37 @@ class PartialAppInfo:
The application's privacy policy URL, if set. The application's privacy policy URL, if set.
""" """
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon') __slots__ = (
"_state",
"id",
"name",
"description",
"rpc_origins",
"summary",
"verify_key",
"terms_of_service_url",
"privacy_policy_url",
"_icon",
)
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self._icon: Optional[str] = data.get('icon') self._icon: Optional[str] = data.get("icon")
self.description: str = data['description'] self.description: str = data["description"]
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins') self.rpc_origins: Optional[List[str]] = data.get("rpc_origins")
self.summary: str = data['summary'] self.summary: str = data["summary"]
self.verify_key: str = data['verify_key'] self.verify_key: str = data["verify_key"]
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>' return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>"
@property @property
def icon(self) -> Optional[Asset]: def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None: if self._icon is None:
return None return None
return Asset._from_icon(self._state, self.id, self._icon, path='app') return Asset._from_icon(self._state, self.id, self._icon, path="app")

View File

@ -33,13 +33,11 @@ from . import utils
import yarl import yarl
__all__ = ( __all__ = ("Asset",)
'Asset',
)
if TYPE_CHECKING: if TYPE_CHECKING:
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
MISSING = utils.MISSING MISSING = utils.MISSING
class AssetMixin: class AssetMixin:
url: str url: str
_state: Optional[Any] _state: Optional[Any]
@ -71,7 +70,7 @@ class AssetMixin:
The content of the asset. The content of the asset.
""" """
if self._state is None: if self._state is None:
raise DiscordException('Invalid state (no ConnectionState provided)') raise DiscordException("Invalid state (no ConnectionState provided)")
return await self._state.http.get_from_cdn(self.url) return await self._state.http.get_from_cdn(self.url)
@ -112,7 +111,7 @@ class AssetMixin:
fp.seek(0) fp.seek(0)
return written return written
else: else:
with open(fp, 'wb') as f: with open(fp, "wb") as f:
return f.write(data) return f.write(data)
@ -143,13 +142,13 @@ class Asset(AssetMixin):
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'_state', "_state",
'_url', "_url",
'_animated', "_animated",
'_key', "_key",
) )
BASE = 'https://cdn.discordapp.com' BASE = "https://cdn.discordapp.com"
def __init__(self, state, *, url: str, key: str, animated: bool = False): def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state self._state = state
@ -161,26 +160,26 @@ class Asset(AssetMixin):
def _from_default_avatar(cls, state, index: int) -> Asset: def _from_default_avatar(cls, state, index: int) -> Asset:
return cls( return cls(
state, state,
url=f'{cls.BASE}/embed/avatars/{index}.png', url=f"{cls.BASE}/embed/avatars/{index}.png",
key=str(index), key=str(index),
animated=False, animated=False,
) )
@classmethod @classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_') animated = avatar.startswith("a_")
format = 'gif' if animated else 'png' format = "gif" if animated else "png"
return cls( return cls(
state, state,
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024', url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024",
key=avatar, key=avatar,
animated=animated, animated=animated,
) )
@classmethod @classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
animated = avatar.startswith('a_') animated = avatar.startswith("a_")
format = 'gif' if animated else 'png' format = "gif" if animated else "png"
return cls( return cls(
state, state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
@ -192,7 +191,7 @@ class Asset(AssetMixin):
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
return cls( return cls(
state, state,
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
key=icon_hash, key=icon_hash,
animated=False, animated=False,
) )
@ -201,7 +200,7 @@ class Asset(AssetMixin):
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
key=cover_image_hash, key=cover_image_hash,
animated=False, animated=False,
) )
@ -210,18 +209,18 @@ class Asset(AssetMixin):
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
return cls( return cls(
state, state,
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024', url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
key=image, key=image,
animated=False, animated=False,
) )
@classmethod @classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
animated = icon_hash.startswith('a_') animated = icon_hash.startswith("a_")
format = 'gif' if animated else 'png' format = "gif" if animated else "png"
return cls( return cls(
state, state,
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024",
key=icon_hash, key=icon_hash,
animated=animated, animated=animated,
) )
@ -230,20 +229,20 @@ class Asset(AssetMixin):
def _from_sticker_banner(cls, state, banner: int) -> Asset: def _from_sticker_banner(cls, state, banner: int) -> Asset:
return cls( return cls(
state, state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
key=str(banner), key=str(banner),
animated=False, animated=False,
) )
@classmethod @classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
animated = banner_hash.startswith('a_') animated = banner_hash.startswith("a_")
format = 'gif' if animated else 'png' format = "gif" if animated else "png"
return cls( return cls(
state, state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512",
key=banner_hash, key=banner_hash,
animated=animated animated=animated,
) )
def __str__(self) -> str: def __str__(self) -> str:
@ -253,8 +252,8 @@ class Asset(AssetMixin):
return len(self._url) return len(self._url)
def __repr__(self): def __repr__(self):
shorten = self._url.replace(self.BASE, '') shorten = self._url.replace(self.BASE, "")
return f'<Asset url={shorten!r}>' return f"<Asset url={shorten!r}>"
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Asset) and self._url == other._url return isinstance(other, Asset) and self._url == other._url
@ -312,20 +311,21 @@ class Asset(AssetMixin):
if format is not MISSING: if format is not MISSING:
if self._animated: if self._animated:
if format not in VALID_ASSET_FORMATS: if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
else: url = url.with_path(f"{path}.{format}")
elif static_format is MISSING:
if format not in VALID_STATIC_FORMATS: if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f'{path}.{format}') url = url.with_path(f"{path}.{format}")
if static_format is not MISSING and not self._animated: if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS: if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}') raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f'{path}.{static_format}') url = url.with_path(f"{path}.{static_format}")
if size is not MISSING: if size is not MISSING:
if not utils.valid_icon_size(size): if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096') raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = url.with_query(size=size) url = url.with_query(size=size)
else: else:
url = url.with_query(url.raw_query_string) url = url.with_query(url.raw_query_string)
@ -352,7 +352,7 @@ class Asset(AssetMixin):
The new updated asset. The new updated asset.
""" """
if not utils.valid_icon_size(size): if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096') raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url = str(yarl.URL(self._url).with_query(size=size)) url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
@ -378,14 +378,14 @@ class Asset(AssetMixin):
if self._animated: if self._animated:
if format not in VALID_ASSET_FORMATS: if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
else: else:
if format not in VALID_STATIC_FORMATS: if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = yarl.URL(self._url) url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path) path, _ = os.path.splitext(url.path)
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated) return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:

View File

@ -35,9 +35,9 @@ from .object import Object
from .permissions import PermissionOverwrite, Permissions from .permissions import PermissionOverwrite, Permissions
__all__ = ( __all__ = (
'AuditLogDiff', "AuditLogDiff",
'AuditLogChanges', "AuditLogChanges",
'AuditLogEntry', "AuditLogEntry",
) )
@ -85,6 +85,7 @@ def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Uni
return None return None
return entry._get_member(int(data)) return entry._get_member(int(data))
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
if data is None: if data is None:
return None return None
@ -96,16 +97,16 @@ def _transform_overwrites(
) -> List[Tuple[Object, PermissionOverwrite]]: ) -> List[Tuple[Object, PermissionOverwrite]]:
overwrites = [] overwrites = []
for elem in data: for elem in data:
allow = Permissions(int(elem['allow'])) allow = Permissions(int(elem["allow"]))
deny = Permissions(int(elem['deny'])) deny = Permissions(int(elem["deny"]))
ow = PermissionOverwrite.from_pair(allow, deny) ow = PermissionOverwrite.from_pair(allow, deny)
ow_type = elem['type'] ow_type = elem["type"]
ow_id = int(elem['id']) ow_id = int(elem["id"])
target = None target = None
if ow_type == '0': if ow_type == "0":
target = entry.guild.get_role(ow_id) target = entry.guild.get_role(ow_id)
elif ow_type == '1': elif ow_type == "1":
target = entry._get_member(ow_id) target = entry._get_member(ow_id)
if target is None: if target is None:
@ -137,7 +138,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
return _transform return _transform
T = TypeVar('T', bound=enums.Enum) T = TypeVar("T", bound=enums.Enum)
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
@ -146,12 +147,14 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]: def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'): if entry.action.name.startswith("sticker_"):
return enums.try_enum(enums.StickerType, data) return enums.try_enum(enums.StickerType, data)
else: else:
return enums.try_enum(enums.ChannelType, data) return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff: class AuditLogDiff:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.__dict__) return len(self.__dict__)
@ -160,8 +163,8 @@ class AuditLogDiff:
yield from self.__dict__.items() yield from self.__dict__.items()
def __repr__(self) -> str: def __repr__(self) -> str:
values = ' '.join('%s=%r' % item for item in self.__dict__.items()) values = " ".join("%s=%r" % item for item in self.__dict__.items())
return f'<AuditLogDiff {values}>' return f"<AuditLogDiff {values}>"
if TYPE_CHECKING: if TYPE_CHECKING:
@ -217,14 +220,14 @@ class AuditLogChanges:
self.after = AuditLogDiff() self.after = AuditLogDiff()
for elem in data: for elem in data:
attr = elem['key'] attr = elem["key"]
# special cases for role add/remove # special cases for role add/remove
if attr == '$add': if attr == "$add":
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore
continue continue
elif attr == '$remove': elif attr == "$remove":
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore
continue continue
try: try:
@ -238,7 +241,7 @@ class AuditLogChanges:
transformer: Optional[Transformer] transformer: Optional[Transformer]
try: try:
before = elem['old_value'] before = elem["old_value"]
except KeyError: except KeyError:
before = None before = None
else: else:
@ -248,7 +251,7 @@ class AuditLogChanges:
setattr(self.before, attr, before) setattr(self.before, attr, before)
try: try:
after = elem['new_value'] after = elem["new_value"]
except KeyError: except KeyError:
after = None after = None
else: else:
@ -258,34 +261,36 @@ class AuditLogChanges:
setattr(self.after, attr, after) setattr(self.after, attr, after)
# add an alias # add an alias
if hasattr(self.after, 'colour'): if hasattr(self.after, "colour"):
self.after.color = self.after.colour self.after.color = self.after.colour
self.before.color = self.before.colour self.before.color = self.before.colour
if hasattr(self.after, 'expire_behavior'): if hasattr(self.after, "expire_behavior"):
self.after.expire_behaviour = self.after.expire_behavior self.after.expire_behaviour = self.after.expire_behavior
self.before.expire_behaviour = self.before.expire_behavior self.before.expire_behaviour = self.before.expire_behavior
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>' return f"<AuditLogChanges before={self.before!r} after={self.after!r}>"
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: def _handle_role(
if not hasattr(first, 'roles'): self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]
setattr(first, 'roles', []) ) -> None:
if not hasattr(first, "roles"):
setattr(first, "roles", [])
data = [] data = []
g: Guild = entry.guild # type: ignore g: Guild = entry.guild # type: ignore
for e in elem: for e in elem:
role_id = int(e['id']) role_id = int(e["id"])
role = g.get_role(role_id) role = g.get_role(role_id)
if role is None: if role is None:
role = Object(id=role_id) role = Object(id=role_id)
role.name = e['name'] # type: ignore role.name = e["name"] # type: ignore
data.append(role) data.append(role)
setattr(second, 'roles', data) setattr(second, "roles", data)
class _AuditLogProxyMemberPrune: class _AuditLogProxyMemberPrune:
@ -330,6 +335,10 @@ class AuditLogEntry(Hashable):
Returns the entry's hash. Returns the entry's hash.
.. describe:: int(x)
Returns the entry's ID.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Audit log entries are now comparable and hashable. Audit log entries are now comparable and hashable.
@ -361,56 +370,56 @@ class AuditLogEntry(Hashable):
self._from_data(data) self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None: def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) self.action = enums.try_enum(enums.AuditLogAction, data["action_type"])
self.id = int(data['id']) self.id = int(data["id"])
# this key is technically not usually present # this key is technically not usually present
self.reason = data.get('reason') self.reason = data.get("reason")
self.extra = data.get('options') self.extra = data.get("options")
if isinstance(self.action, enums.AuditLogAction) and self.extra: if isinstance(self.action, enums.AuditLogAction) and self.extra:
if self.action is enums.AuditLogAction.member_prune: if self.action is enums.AuditLogAction.member_prune:
# member prune has two keys with useful information # member prune has two keys with useful information
self.extra: _AuditLogProxyMemberPrune = type( self.extra: _AuditLogProxyMemberPrune = type(
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()} "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()}
)() )()
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
channel_id = int(self.extra['channel_id']) channel_id = int(self.extra["channel_id"])
elems = { elems = {
'count': int(self.extra['count']), "count": int(self.extra["count"]),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), "channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
} }
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)() self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)()
elif self.action is enums.AuditLogAction.member_disconnect: elif self.action is enums.AuditLogAction.member_disconnect:
# The member disconnect action has a dict with some information # The member disconnect action has a dict with some information
elems = { elems = {
'count': int(self.extra['count']), "count": int(self.extra["count"]),
} }
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)() self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)()
elif self.action.name.endswith('pin'): elif self.action.name.endswith("pin"):
# the pin actions have a dict with some information # the pin actions have a dict with some information
channel_id = int(self.extra['channel_id']) channel_id = int(self.extra["channel_id"])
elems = { elems = {
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), "channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
'message_id': int(self.extra['message_id']), "message_id": int(self.extra["message_id"]),
} }
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)() self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)()
elif self.action.name.startswith('overwrite_'): elif self.action.name.startswith("overwrite_"):
# the overwrite_ actions have a dict with some information # the overwrite_ actions have a dict with some information
instance_id = int(self.extra['id']) instance_id = int(self.extra["id"])
the_type = self.extra.get('type') the_type = self.extra.get("type")
if the_type == '1': if the_type == "1":
self.extra = self._get_member(instance_id) self.extra = self._get_member(instance_id)
elif the_type == '0': elif the_type == "0":
role = self.guild.get_role(instance_id) role = self.guild.get_role(instance_id)
if role is None: if role is None:
role = Object(id=instance_id) role = Object(id=instance_id)
role.name = self.extra.get('role_name') # type: ignore role.name = self.extra.get("role_name") # type: ignore
self.extra: Role = role self.extra: Role = role
elif self.action.name.startswith('stage_instance'): elif self.action.name.startswith("stage_instance"):
channel_id = int(self.extra['channel_id']) channel_id = int(self.extra["channel_id"])
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)} elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)() self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)()
# fmt: off # fmt: off
self.extra: Union[ self.extra: Union[
@ -429,16 +438,16 @@ class AuditLogEntry(Hashable):
# where new_value and old_value are not guaranteed to be there depending # where new_value and old_value are not guaranteed to be there depending
# on the action type, so let's just fetch it for now and only turn it # on the action type, so let's just fetch it for now and only turn it
# into meaningful data when requested # into meaningful data when requested
self._changes = data.get('changes', []) self._changes = data.get("changes", [])
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
self._target_id = utils._get_as_snowflake(data, 'target_id') self._target_id = utils._get_as_snowflake(data, "target_id")
def _get_member(self, user_id: int) -> Union[Member, User, None]: def _get_member(self, user_id: int) -> Union[Member, User, None]:
return self.guild.get_member(user_id) or self._users.get(user_id) return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>' return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
@utils.cached_property @utils.cached_property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:
@ -446,9 +455,13 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
@utils.cached_property @utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: def target(
self,
) -> Union[
Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None
]:
try: try:
converter = getattr(self, '_convert_target_' + self.action.target_type) converter = getattr(self, "_convert_target_" + self.action.target_type)
except AttributeError: except AttributeError:
return Object(id=self._target_id) return Object(id=self._target_id)
else: else:
@ -494,11 +507,11 @@ class AuditLogEntry(Hashable):
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
fake_payload = { fake_payload = {
'max_age': changeset.max_age, "max_age": changeset.max_age,
'max_uses': changeset.max_uses, "max_uses": changeset.max_uses,
'code': changeset.code, "code": changeset.code,
'temporary': changeset.temporary, "temporary": changeset.temporary,
'uses': changeset.uses, "uses": changeset.uses,
} }
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore

View File

@ -29,11 +29,10 @@ import time
import random import random
from typing import Callable, Generic, Literal, TypeVar, overload, Union from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar('T', bool, Literal[True], Literal[False]) T = TypeVar("T", bool, Literal[True], Literal[False])
__all__ = ("ExponentialBackoff",)
__all__ = (
'ExponentialBackoff',
)
class ExponentialBackoff(Generic[T]): class ExponentialBackoff(Generic[T]):
"""An implementation of the exponential backoff algorithm """An implementation of the exponential backoff algorithm
@ -69,7 +68,7 @@ class ExponentialBackoff(Generic[T]):
rand = random.Random() rand = random.Random()
rand.seed() rand.seed()
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
@overload @overload
def delay(self: ExponentialBackoff[Literal[False]]) -> float: def delay(self: ExponentialBackoff[Literal[False]]) -> float:

View File

@ -57,14 +57,14 @@ from .threads import Thread
from .iterators import ArchivedThreadIterator from .iterators import ArchivedThreadIterator
__all__ = ( __all__ = (
'TextChannel', "TextChannel",
'VoiceChannel', "VoiceChannel",
'StageChannel', "StageChannel",
'DMChannel', "DMChannel",
'CategoryChannel', "CategoryChannel",
'StoreChannel', "StoreChannel",
'GroupChannel', "GroupChannel",
'PartialMessageable', "PartialMessageable",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -115,6 +115,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
Returns the channel's name. Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -151,51 +155,51 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
""" """
__slots__ = ( __slots__ = (
'name', "name",
'id', "id",
'guild', "guild",
'topic', "topic",
'_state', "_state",
'nsfw', "nsfw",
'category_id', "category_id",
'position', "position",
'slowmode_delay', "slowmode_delay",
'_overwrites', "_overwrites",
'_type', "_type",
'last_message_id', "last_message_id",
'default_auto_archive_duration', "default_auto_archive_duration",
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self._type: int = data['type'] self._type: int = data["type"]
self._update(guild, data) self._update(guild, data)
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = [ attrs = [
('id', self.id), ("id", self.id),
('name', self.name), ("name", self.name),
('position', self.position), ("position", self.position),
('nsfw', self.nsfw), ("nsfw", self.nsfw),
('news', self.is_news()), ("news", self.is_news()),
('category_id', self.category_id), ("category_id", self.category_id),
] ]
joined = ' '.join('%s=%r' % t for t in attrs) joined = " ".join("%s=%r" % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>' return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: TextChannelPayload) -> None: def _update(self, guild: Guild, data: TextChannelPayload) -> None:
self.guild: Guild = guild self.guild: Guild = guild
self.name: str = data['name'] self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.topic: Optional[str] = data.get('topic') self.topic: Optional[str] = data.get("topic")
self.position: int = data['position'] self.position: int = data["position"]
self.nsfw: bool = data.get('nsfw', False) self.nsfw: bool = data.get("nsfw", False)
# Does this need coercion into `int`? No idea yet. # Does this need coercion into `int`? No idea yet.
self.slowmode_delay: int = data.get('rate_limit_per_user', 0) self.slowmode_delay: int = data.get("rate_limit_per_user", 0)
self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440)
self._type: int = data.get('type', self._type) self._type: int = data.get("type", self._type)
self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id")
self._fill_overwrites(data) self._fill_overwrites(data)
async def _get_channel(self): async def _get_channel(self):
@ -224,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
"""List[:class:`Member`]: Returns all members that can see this channel.""" """List[:class:`Member`]: Returns all members that can see this channel."""
return [m for m in self.guild.members if self.permissions_for(m).read_messages] return [m for m in self.guild.members if self.permissions_for(m).read_messages]
@property
def bots(self) -> List[Member]:
"""List[:class:`Member`]: Returns all bots that can see this channel."""
return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages]
@property
def humans(self) -> List[Member]:
"""List[:class:`Member`]: Returns all human members that can see this channel."""
return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages]
@property @property
def threads(self) -> List[Thread]: def threads(self) -> List[Thread]:
"""List[:class:`Thread`]: Returns all the threads that you can see. """List[:class:`Thread`]: Returns all the threads that you can see.
@ -357,7 +371,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone) @utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel:
return await self._clone_impl( return await self._clone_impl(
{'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason {"topic": self.topic, "nsfw": self.nsfw, "rate_limit_per_user": self.slowmode_delay},
name=name,
reason=reason,
) )
async def delete_messages(self, messages: Iterable[Snowflake]) -> None: async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@ -404,7 +420,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
return return
if len(messages) > 100: if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages') raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages] message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids) await self._state.http.delete_messages(self.id, message_ids)
@ -544,7 +560,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
data = await self._state.http.channel_webhooks(self.id) data = await self._state.http.channel_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data] return [Webhook.from_state(d, state=self._state) for d in data]
async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: async def create_webhook(
self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None
) -> Webhook:
"""|coro| """|coro|
Creates a webhook for this channel. Creates a webhook for this channel.
@ -621,10 +639,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
""" """
if not self.is_news(): if not self.is_news():
raise ClientException('The channel must be a news channel.') raise ClientException("The channel must be a news channel.")
if not isinstance(destination, TextChannel): if not isinstance(destination, TextChannel):
raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}') raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}")
from .webhook import Webhook from .webhook import Webhook
@ -788,40 +806,40 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable):
class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable):
__slots__ = ( __slots__ = (
'name', "name",
'id', "id",
'guild', "guild",
'bitrate', "bitrate",
'user_limit', "user_limit",
'_state', "_state",
'position', "position",
'_overwrites', "_overwrites",
'category_id', "category_id",
'rtc_region', "rtc_region",
'video_quality_mode', "video_quality_mode",
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self._update(guild, data) self._update(guild, data)
def _get_voice_client_key(self) -> Tuple[int, str]: def _get_voice_client_key(self) -> Tuple[int, str]:
return self.guild.id, 'guild_id' return self.guild.id, "guild_id"
def _get_voice_state_pair(self) -> Tuple[int, int]: def _get_voice_state_pair(self) -> Tuple[int, int]:
return self.guild.id, self.id return self.guild.id, self.id
def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None:
self.guild = guild self.guild = guild
self.name: str = data['name'] self.name: str = data["name"]
rtc = data.get('rtc_region') rtc = data.get("rtc_region")
self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None
self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1))
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data['position'] self.position: int = data["position"]
self.bitrate: int = data.get('bitrate') self.bitrate: int = data.get("bitrate")
self.user_limit: int = data.get('user_limit') self.user_limit: int = data.get("user_limit")
self._fill_overwrites(data) self._fill_overwrites(data)
@property @property
@ -929,17 +947,17 @@ class VoiceChannel(VocalGuildChannel):
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = [ attrs = [
('id', self.id), ("id", self.id),
('name', self.name), ("name", self.name),
('rtc_region', self.rtc_region), ("rtc_region", self.rtc_region),
('position', self.position), ("position", self.position),
('bitrate', self.bitrate), ("bitrate", self.bitrate),
('video_quality_mode', self.video_quality_mode), ("video_quality_mode", self.video_quality_mode),
('user_limit', self.user_limit), ("user_limit", self.user_limit),
('category_id', self.category_id), ("category_id", self.category_id),
] ]
joined = ' '.join('%s=%r' % t for t in attrs) joined = " ".join("%s=%r" % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>' return f"<{self.__class__.__name__} {joined}>"
@property @property
def type(self) -> ChannelType: def type(self) -> ChannelType:
@ -948,7 +966,9 @@ class VoiceChannel(VocalGuildChannel):
@utils.copy_doc(discord.abc.GuildChannel.clone) @utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel:
return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) return await self._clone_impl(
{"bitrate": self.bitrate, "user_limit": self.user_limit}, name=name, reason=reason
)
@overload @overload
async def edit( async def edit(
@ -1089,26 +1109,26 @@ class StageChannel(VocalGuildChannel):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__slots__ = ('topic',) __slots__ = ("topic",)
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = [ attrs = [
('id', self.id), ("id", self.id),
('name', self.name), ("name", self.name),
('topic', self.topic), ("topic", self.topic),
('rtc_region', self.rtc_region), ("rtc_region", self.rtc_region),
('position', self.position), ("position", self.position),
('bitrate', self.bitrate), ("bitrate", self.bitrate),
('video_quality_mode', self.video_quality_mode), ("video_quality_mode", self.video_quality_mode),
('user_limit', self.user_limit), ("user_limit", self.user_limit),
('category_id', self.category_id), ("category_id", self.category_id),
] ]
joined = ' '.join('%s=%r' % t for t in attrs) joined = " ".join("%s=%r" % t for t in attrs)
return f'<{self.__class__.__name__} {joined}>' return f"<{self.__class__.__name__} {joined}>"
def _update(self, guild: Guild, data: StageChannelPayload) -> None: def _update(self, guild: Guild, data: StageChannelPayload) -> None:
super()._update(guild, data) super()._update(guild, data)
self.topic = data.get('topic') self.topic = data.get("topic")
@property @property
def requesting_to_speak(self) -> List[Member]: def requesting_to_speak(self) -> List[Member]:
@ -1197,13 +1217,13 @@ class StageChannel(VocalGuildChannel):
The newly created stage instance. The newly created stage instance.
""" """
payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic} payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic}
if privacy_level is not MISSING: if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel): if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel') raise InvalidArgument("privacy_level field must be of type PrivacyLevel")
payload['privacy_level'] = privacy_level.value payload["privacy_level"] = privacy_level.value
data = await self._state.http.create_stage_instance(**payload, reason=reason) data = await self._state.http.create_stage_instance(**payload, reason=reason)
return StageInstance(guild=self.guild, state=self._state, data=data) return StageInstance(guild=self.guild, state=self._state, data=data)
@ -1334,6 +1354,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
Returns the category's name. Returns the category's name.
.. describe:: int(x)
Returns the category's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -1353,22 +1377,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead.
""" """
__slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') __slots__ = ("name", "id", "guild", "nsfw", "_state", "position", "_overwrites", "category_id")
def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self._update(guild, data) self._update(guild, data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>' return f"<CategoryChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: def _update(self, guild: Guild, data: CategoryChannelPayload) -> None:
self.guild: Guild = guild self.guild: Guild = guild
self.name: str = data['name'] self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.nsfw: bool = data.get('nsfw', False) self.nsfw: bool = data.get("nsfw", False)
self.position: int = data['position'] self.position: int = data["position"]
self._fill_overwrites(data) self._fill_overwrites(data)
@property @property
@ -1386,7 +1410,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone) @utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload @overload
async def edit( async def edit(
@ -1455,7 +1479,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.move) @utils.copy_doc(discord.abc.GuildChannel.move)
async def move(self, **kwargs): async def move(self, **kwargs):
kwargs.pop('category', None) kwargs.pop("category", None)
await super().move(**kwargs) await super().move(**kwargs)
@property @property
@ -1556,6 +1580,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
Returns the channel's name. Returns the channel's name.
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -1578,30 +1606,30 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
""" """
__slots__ = ( __slots__ = (
'name', "name",
'id', "id",
'guild', "guild",
'_state', "_state",
'nsfw', "nsfw",
'category_id', "category_id",
'position', "position",
'_overwrites', "_overwrites",
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self._update(guild, data) self._update(guild, data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>' return f"<StoreChannel id={self.id} name={self.name!r} position={self.position} nsfw={self.nsfw}>"
def _update(self, guild: Guild, data: StoreChannelPayload) -> None: def _update(self, guild: Guild, data: StoreChannelPayload) -> None:
self.guild: Guild = guild self.guild: Guild = guild
self.name: str = data['name'] self.name: str = data["name"]
self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id")
self.position: int = data['position'] self.position: int = data["position"]
self.nsfw: bool = data.get('nsfw', False) self.nsfw: bool = data.get("nsfw", False)
self._fill_overwrites(data) self._fill_overwrites(data)
@property @property
@ -1628,7 +1656,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
@utils.copy_doc(discord.abc.GuildChannel.clone) @utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel:
return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason)
@overload @overload
async def edit( async def edit(
@ -1704,7 +1732,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable):
return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore
DMC = TypeVar('DMC', bound='DMChannel') DMC = TypeVar("DMC", bound="DMChannel")
class DMChannel(discord.abc.Messageable, Hashable): class DMChannel(discord.abc.Messageable, Hashable):
@ -1728,6 +1756,10 @@ class DMChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
---------- ----------
recipient: Optional[:class:`User`] recipient: Optional[:class:`User`]
@ -1740,24 +1772,24 @@ class DMChannel(discord.abc.Messageable, Hashable):
The direct message channel ID. The direct message channel ID.
""" """
__slots__ = ('id', 'recipient', 'me', '_state') __slots__ = ("id", "recipient", "me", "_state")
def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.recipient: Optional[User] = state.store_user(data['recipients'][0]) self.recipient: Optional[User] = state.store_user(data["recipients"][0])
self.me: ClientUser = me self.me: ClientUser = me
self.id: int = int(data['id']) self.id: int = int(data["id"])
async def _get_channel(self): async def _get_channel(self):
return self return self
def __str__(self) -> str: def __str__(self) -> str:
if self.recipient: if self.recipient:
return f'Direct Message with {self.recipient}' return f"Direct Message with {self.recipient}"
return 'Direct Message with Unknown User' return "Direct Message with Unknown User"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<DMChannel id={self.id} recipient={self.recipient!r}>' return f"<DMChannel id={self.id} recipient={self.recipient!r}>"
@classmethod @classmethod
def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC: def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC:
@ -1854,6 +1886,10 @@ class GroupChannel(discord.abc.Messageable, Hashable):
Returns a string representation of the channel Returns a string representation of the channel
.. describe:: int(x)
Returns the channel's ID.
Attributes Attributes
---------- ----------
recipients: List[:class:`User`] recipients: List[:class:`User`]
@ -1872,19 +1908,19 @@ class GroupChannel(discord.abc.Messageable, Hashable):
The group channel's name if provided. The group channel's name if provided.
""" """
__slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state') __slots__ = ("id", "recipients", "owner_id", "owner", "_icon", "name", "me", "_state")
def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.me: ClientUser = me self.me: ClientUser = me
self._update_group(data) self._update_group(data)
def _update_group(self, data: GroupChannelPayload) -> None: def _update_group(self, data: GroupChannelPayload) -> None:
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id') self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id")
self._icon: Optional[str] = data.get('icon') self._icon: Optional[str] = data.get("icon")
self.name: Optional[str] = data.get('name') self.name: Optional[str] = data.get("name")
self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])] self.recipients: List[User] = [self._state.store_user(u) for u in data.get("recipients", [])]
self.owner: Optional[BaseUser] self.owner: Optional[BaseUser]
if self.owner_id == self.me.id: if self.owner_id == self.me.id:
@ -1900,12 +1936,12 @@ class GroupChannel(discord.abc.Messageable, Hashable):
return self.name return self.name
if len(self.recipients) == 0: if len(self.recipients) == 0:
return 'Unnamed' return "Unnamed"
return ', '.join(map(lambda x: x.name, self.recipients)) return ", ".join(map(lambda x: x.name, self.recipients))
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<GroupChannel id={self.id} name={self.name!r}>' return f"<GroupChannel id={self.id} name={self.name!r}>"
@property @property
def type(self) -> ChannelType: def type(self) -> ChannelType:
@ -1917,7 +1953,7 @@ class GroupChannel(discord.abc.Messageable, Hashable):
"""Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" """Optional[:class:`Asset`]: Returns the channel's icon asset if available."""
if self._icon is None: if self._icon is None:
return None return None
return Asset._from_icon(self._state, self.id, self._icon, path='channel') return Asset._from_icon(self._state, self.id, self._icon, path="channel")
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:
@ -2000,6 +2036,10 @@ class PartialMessageable(discord.abc.Messageable, Hashable):
Returns the partial messageable's hash. Returns the partial messageable's hash.
.. describe:: int(x)
Returns the messageable's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`

View File

@ -29,7 +29,20 @@ import logging
import signal import signal
import sys import sys
import traceback import traceback
from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union from typing import (
Any,
Callable,
Coroutine,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
TypeVar,
Union,
)
import aiohttp import aiohttp
@ -69,46 +82,49 @@ if TYPE_CHECKING:
from .member import Member from .member import Member
from .voice_client import VoiceProtocol from .voice_client import VoiceProtocol
__all__ = ( __all__ = ("Client",)
'Client',
)
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]])
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None:
tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()}
if not tasks: if not tasks:
return return
_log.info('Cleaning up after %d tasks.', len(tasks)) _log.info("Cleaning up after %d tasks.", len(tasks))
for task in tasks: for task in tasks:
task.cancel() task.cancel()
loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
_log.info('All tasks finished cancelling.') _log.info("All tasks finished cancelling.")
for task in tasks: for task in tasks:
if task.cancelled(): if task.cancelled():
continue continue
if task.exception() is not None: if task.exception() is not None:
loop.call_exception_handler({ loop.call_exception_handler(
'message': 'Unhandled exception during Client.run shutdown.', {
'exception': task.exception(), "message": "Unhandled exception during Client.run shutdown.",
'task': task "exception": task.exception(),
}) "task": task,
}
)
def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None:
try: try:
_cancel_tasks(loop) _cancel_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens()) loop.run_until_complete(loop.shutdown_asyncgens())
finally: finally:
_log.info('Closing the event loop.') _log.info("Closing the event loop.")
loop.close() loop.close()
class Client: class Client:
r"""Represents a client connection that connects to Discord. r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API. This class is used to interact with the Discord WebSocket and API.
@ -142,7 +158,6 @@ class Client:
intents: :class:`Intents` intents: :class:`Intents`
The intents that you want to enable for the session. This is a way of The intents that you want to enable for the session. This is a way of
disabling and enabling certain gateway events from triggering and being sent. disabling and enabling certain gateway events from triggering and being sent.
If not given, defaults to a regularly constructed :class:`Intents` class.
.. versionadded:: 1.5 .. versionadded:: 1.5
member_cache_flags: :class:`MemberCacheFlags` member_cache_flags: :class:`MemberCacheFlags`
@ -200,34 +215,36 @@ class Client:
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the client uses for asynchronous operations. The event loop that the client uses for asynchronous operations.
""" """
def __init__( def __init__(
self, self,
*, *,
intents: Intents,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
**options: Any, **options: Any,
): ):
options["intents"] = intents
# self.ws is set in the connect method # self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.shard_id: Optional[int] = options.get('shard_id') self.shard_id: Optional[int] = options.get("shard_id")
self.shard_count: Optional[int] = options.get('shard_count') self.shard_count: Optional[int] = options.get("shard_count")
connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None) connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None)
proxy: Optional[str] = options.pop('proxy', None) proxy: Optional[str] = options.pop("proxy", None)
proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None)
unsync_clock: bool = options.pop('assume_unsync_clock', True) unsync_clock: bool = options.pop("assume_unsync_clock", True)
self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) self.http: HTTPClient = HTTPClient(
connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop
)
self._handlers: Dict[str, Callable] = { self._handlers: Dict[str, Callable] = {"ready": self._handle_ready}
'ready': self._handle_ready
}
self._hooks: Dict[str, Callable] = { self._hooks: Dict[str, Callable] = {"before_identify": self._call_before_identify_hook}
'before_identify': self._call_before_identify_hook
}
self._enable_debug_events: bool = options.pop('enable_debug_events', False) self._enable_debug_events: bool = options.pop("enable_debug_events", False)
self._connection: ConnectionState = self._get_state(**options) self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count self._connection.shard_count = self.shard_count
self._closed: bool = False self._closed: bool = False
@ -245,8 +262,14 @@ class Client:
return self.ws return self.ws
def _get_state(self, **options: Any) -> ConnectionState: def _get_state(self, **options: Any) -> ConnectionState:
return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, return ConnectionState(
hooks=self._hooks, http=self.http, loop=self.loop, **options) dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
def _handle_ready(self) -> None: def _handle_ready(self) -> None:
self._ready.set() self._ready.set()
@ -258,7 +281,7 @@ class Client:
This could be referred to as the Discord WebSocket protocol latency. This could be referred to as the Discord WebSocket protocol latency.
""" """
ws = self.ws ws = self.ws
return float('nan') if not ws else ws.latency return float("nan") if not ws else ws.latency
def is_ws_ratelimited(self) -> bool: def is_ws_ratelimited(self) -> bool:
""":class:`bool`: Whether the websocket is currently rate limited. """:class:`bool`: Whether the websocket is currently rate limited.
@ -346,7 +369,9 @@ class Client:
""":class:`bool`: Specifies if the client's internal cache is ready for use.""" """:class:`bool`: Specifies if the client's internal cache is ready for use."""
return self._ready.is_set() return self._ready.is_set()
async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: async def _run_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> None:
try: try:
await coro(*args, **kwargs) await coro(*args, **kwargs)
except asyncio.CancelledError: except asyncio.CancelledError:
@ -357,14 +382,16 @@ class Client:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: def _schedule_event(
self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any
) -> asyncio.Task:
wrapped = self._run_event(coro, event_name, *args, **kwargs) wrapped = self._run_event(coro, event_name, *args, **kwargs)
# Schedules the task # Schedules the task
return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') return asyncio.create_task(wrapped, name=f"discord.py: {event_name}")
def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
_log.debug('Dispatching event %s', event) _log.debug("Dispatching event %s", event)
method = 'on_' + event method = "on_" + event
listeners = self._listeners.get(event) listeners = self._listeners.get(event)
if listeners: if listeners:
@ -411,7 +438,7 @@ class Client:
overridden to have a different implementation. overridden to have a different implementation.
Check :func:`~discord.on_error` for more details. Check :func:`~discord.on_error` for more details.
""" """
print(f'Ignoring exception in {event_method}', file=sys.stderr) print(f"Ignoring exception in {event_method}", file=sys.stderr)
traceback.print_exc() traceback.print_exc()
# hooks # hooks
@ -468,7 +495,7 @@ class Client:
passing status code. passing status code.
""" """
_log.info('logging in using static token') _log.info("logging in using static token")
data = await self.http.static_login(token.strip()) data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data) self._connection.user = ClientUser(state=self._connection, data=data)
@ -500,29 +527,31 @@ class Client:
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
ws_params = { ws_params = {
'initial': True, "initial": True,
'shard_id': self.shard_id, "shard_id": self.shard_id,
} }
while not self.is_closed(): while not self.is_closed():
try: try:
coro = DiscordWebSocket.from_client(self, **ws_params) coro = DiscordWebSocket.from_client(self, **ws_params)
self.ws = await asyncio.wait_for(coro, timeout=60.0) self.ws = await asyncio.wait_for(coro, timeout=60.0)
ws_params['initial'] = False ws_params["initial"] = False
while True: while True:
await self.ws.poll_event() await self.ws.poll_event()
except ReconnectWebSocket as e: except ReconnectWebSocket as e:
_log.info('Got a request to %s the websocket.', e.op) _log.info("Got a request to %s the websocket.", e.op)
self.dispatch('disconnect') self.dispatch("disconnect")
ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id)
continue continue
except (OSError, except (
HTTPException, OSError,
GatewayNotFound, HTTPException,
ConnectionClosed, GatewayNotFound,
aiohttp.ClientError, ConnectionClosed,
asyncio.TimeoutError) as exc: aiohttp.ClientError,
asyncio.TimeoutError,
) as exc:
self.dispatch('disconnect') self.dispatch("disconnect")
if not reconnect: if not reconnect:
await self.close() await self.close()
if isinstance(exc, ConnectionClosed) and exc.code == 1000: if isinstance(exc, ConnectionClosed) and exc.code == 1000:
@ -595,7 +624,7 @@ class Client:
async def start(self, token: str, *, reconnect: bool = True) -> None: async def start(self, token: str, *, reconnect: bool = True) -> None:
"""|coro| """|coro|
A shorthand coroutine for :meth:`login` + :meth:`connect`. A shorthand coroutine for :meth:`login` + :meth:`setup` + :meth:`connect`.
Raises Raises
------- -------
@ -603,8 +632,21 @@ class Client:
An unexpected keyword argument was received. An unexpected keyword argument was received.
""" """
await self.login(token) await self.login(token)
await self.setup()
await self.connect(reconnect=reconnect) await self.connect(reconnect=reconnect)
async def setup(self) -> Any:
"""|coro|
A coroutine to be called to setup the bot, by default this is blank.
To perform asynchronous setup after the bot is logged in but before
it has connected to the Websocket, overwrite this coroutine.
.. versionadded:: 2.0
"""
pass
def run(self, *args: Any, **kwargs: Any) -> None: def run(self, *args: Any, **kwargs: Any) -> None:
"""A blocking call that abstracts away the event loop """A blocking call that abstracts away the event loop
initialisation from you. initialisation from you.
@ -652,10 +694,10 @@ class Client:
try: try:
loop.run_forever() loop.run_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
_log.info('Received signal to terminate bot and event loop.') _log.info("Received signal to terminate bot and event loop.")
finally: finally:
future.remove_done_callback(stop_loop_on_completion) future.remove_done_callback(stop_loop_on_completion)
_log.info('Cleaning up tasks.') _log.info("Cleaning up tasks.")
_cleanup_loop(loop) _cleanup_loop(loop)
if not future.cancelled(): if not future.cancelled():
@ -684,16 +726,16 @@ class Client:
self._connection._activity = None self._connection._activity = None
elif isinstance(value, BaseActivity): elif isinstance(value, BaseActivity):
# ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any]
self._connection._activity = value.to_dict() # type: ignore self._connection._activity = value.to_dict() # type: ignore
else: else:
raise TypeError('activity must derive from BaseActivity.') raise TypeError("activity must derive from BaseActivity.")
@property @property
def status(self): def status(self):
""":class:`.Status`: """:class:`.Status`:
The status being used upon logging on to Discord. The status being used upon logging on to Discord.
.. versionadded: 2.0 .. versionadded:: 2.0
""" """
if self._connection._status in set(state.value for state in Status): if self._connection._status in set(state.value for state in Status):
return Status(self._connection._status) return Status(self._connection._status)
@ -702,11 +744,11 @@ class Client:
@status.setter @status.setter
def status(self, value): def status(self, value):
if value is Status.offline: if value is Status.offline:
self._connection._status = 'invisible' self._connection._status = "invisible"
elif isinstance(value, Status): elif isinstance(value, Status):
self._connection._status = str(value) self._connection._status = str(value)
else: else:
raise TypeError('status must derive from Status.') raise TypeError("status must derive from Status.")
@property @property
def allowed_mentions(self) -> Optional[AllowedMentions]: def allowed_mentions(self) -> Optional[AllowedMentions]:
@ -721,7 +763,7 @@ class Client:
if value is None or isinstance(value, AllowedMentions): if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value self._connection.allowed_mentions = value
else: else:
raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}') raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}")
@property @property
def intents(self) -> Intents: def intents(self) -> Intents:
@ -827,6 +869,38 @@ class Client:
""" """
return self._connection.get_user(id) return self._connection.get_user(id)
async def try_user(self, id: int, /) -> Optional[User]:
"""|coro|
Returns a user with the given ID. If not from cache, the user will be requested from the API.
You do not have to share any guilds with the user to get this information from the API,
however many operations do require that you do.
.. note::
This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead.
.. versionadded:: 2.0
Parameters
-----------
id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""
maybe_user = self.get_user(id)
if maybe_user is not None:
return maybe_user
try:
return await self.fetch_user(id)
except NotFound:
return None
def get_emoji(self, id: int, /) -> Optional[Emoji]: def get_emoji(self, id: int, /) -> Optional[Emoji]:
"""Returns an emoji with the given ID. """Returns an emoji with the given ID.
@ -999,8 +1073,10 @@ class Client:
future = self.loop.create_future() future = self.loop.create_future()
if check is None: if check is None:
def _check(*args): def _check(*args):
return True return True
check = _check check = _check
ev = event.lower() ev = event.lower()
@ -1038,10 +1114,10 @@ class Client:
""" """
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError('event registered must be a coroutine function') raise TypeError("event registered must be a coroutine function")
setattr(self, coro.__name__, coro) setattr(self, coro.__name__, coro)
_log.debug('%s has successfully been registered as an event', coro.__name__) _log.debug("%s has successfully been registered as an event", coro.__name__)
return coro return coro
async def change_presence( async def change_presence(
@ -1080,10 +1156,10 @@ class Client:
""" """
if status is None: if status is None:
status_str = 'online' status_str = "online"
status = Status.online status = Status.online
elif status is Status.offline: elif status is Status.offline:
status_str = 'invisible' status_str = "invisible"
status = Status.offline status = Status.offline
else: else:
status_str = str(status) status_str = str(status)
@ -1105,11 +1181,7 @@ class Client:
# Guild stuff # Guild stuff
def fetch_guilds( def fetch_guilds(
self, self, *, limit: Optional[int] = 100, before: SnowflakeTime = None, after: SnowflakeTime = None
*,
limit: Optional[int] = 100,
before: SnowflakeTime = None,
after: SnowflakeTime = None
) -> GuildIterator: ) -> GuildIterator:
"""Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds.
@ -1189,7 +1261,7 @@ class Client:
""" """
code = utils.resolve_template(code) code = utils.resolve_template(code)
data = await self.http.get_template(code) data = await self.http.get_template(code)
return Template(data=data, state=self._connection) # type: ignore return Template(data=data, state=self._connection) # type: ignore
async def fetch_guild(self, guild_id: int, /) -> Guild: async def fetch_guild(self, guild_id: int, /) -> Guild:
"""|coro| """|coro|
@ -1305,12 +1377,14 @@ class Client:
The stage instance from the stage channel ID. The stage instance from the stage channel ID.
""" """
data = await self.http.get_stage_instance(channel_id) data = await self.http.get_stage_instance(channel_id)
guild = self.get_guild(int(data['guild_id'])) guild = self.get_guild(int(data["guild_id"]))
return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore
# Invite management # Invite management
async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: async def fetch_invite(
self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True
) -> Invite:
"""|coro| """|coro|
Gets an :class:`.Invite` from a discord.gg URL or ID. Gets an :class:`.Invite` from a discord.gg URL or ID.
@ -1426,8 +1500,8 @@ class Client:
The bot's application information. The bot's application information.
""" """
data = await self.http.application_info() data = await self.http.application_info()
if 'rpc_origins' not in data: if "rpc_origins" not in data:
data['rpc_origins'] = None data["rpc_origins"] = None
return AppInfo(self._connection, data) return AppInfo(self._connection, data)
async def fetch_user(self, user_id: int, /) -> User: async def fetch_user(self, user_id: int, /) -> User:
@ -1490,19 +1564,19 @@ class Client:
""" """
data = await self.http.get_channel(channel_id) data = await self.http.get_channel(channel_id)
factory, ch_type = _threaded_channel_factory(data['type']) factory, ch_type = _threaded_channel_factory(data["type"])
if factory is None: if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here # the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore channel = factory(me=self.user, data=data, state=self._connection) # type: ignore
else: else:
# the factory can't be a DMChannel or GroupChannel here # the factory can't be a DMChannel or GroupChannel here
guild_id = int(data['guild_id']) # type: ignore guild_id = int(data["guild_id"]) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id) guild = self.get_guild(guild_id) or Object(id=guild_id)
# GuildChannels expect a Guild, we may be passing an Object # GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
return channel return channel
@ -1548,8 +1622,8 @@ class Client:
The sticker you requested. The sticker you requested.
""" """
data = await self.http.get_sticker(sticker_id) data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data['type']) # type: ignore cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._connection, data=data) # type: ignore return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]: async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro| """|coro|
@ -1569,7 +1643,7 @@ class Client:
All available premium sticker packs. All available premium sticker packs.
""" """
data = await self.http.list_premium_sticker_packs() data = await self.http.list_premium_sticker_packs()
return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']] return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]]
async def create_dm(self, user: Snowflake) -> DMChannel: async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro| """|coro|
@ -1626,10 +1700,10 @@ class Client:
""" """
if not isinstance(view, View): if not isinstance(view, View):
raise TypeError(f'expected an instance of View not {view.__class__!r}') raise TypeError(f"expected an instance of View not {view.__class__!r}")
if not view.is_persistent(): if not view.is_persistent():
raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout') raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout")
self._connection.store_view(view, message_id) self._connection.store_view(view, message_id)

View File

@ -35,11 +35,11 @@ from typing import (
) )
__all__ = ( __all__ = (
'Colour', "Colour",
'Color', "Color",
) )
CT = TypeVar('CT', bound='Colour') CT = TypeVar("CT", bound="Colour")
class Colour: class Colour:
@ -76,16 +76,16 @@ class Colour:
The raw integer colour value. The raw integer colour value.
""" """
__slots__ = ('value',) __slots__ = ("value",)
def __init__(self, value: int): def __init__(self, value: int):
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.') raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.")
self.value: int = value self.value: int = value
def _get_byte(self, byte: int) -> int: def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xff return (self.value >> (8 * byte)) & 0xFF
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, Colour) and self.value == other.value return isinstance(other, Colour) and self.value == other.value
@ -94,13 +94,13 @@ class Colour:
return not self.__eq__(other) return not self.__eq__(other)
def __str__(self) -> str: def __str__(self) -> str:
return f'#{self.value:0>6x}' return f"#{self.value:0>6x}"
def __int__(self) -> int: def __int__(self) -> int:
return self.value return self.value
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Colour value={self.value}>' return f"<Colour value={self.value}>"
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(self.value) return hash(self.value)
@ -164,12 +164,12 @@ class Colour:
@classmethod @classmethod
def teal(cls: Type[CT]) -> CT: def teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
return cls(0x1abc9c) return cls(0x1ABC9C)
@classmethod @classmethod
def dark_teal(cls: Type[CT]) -> CT: def dark_teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806a) return cls(0x11806A)
@classmethod @classmethod
def brand_green(cls: Type[CT]) -> CT: def brand_green(cls: Type[CT]) -> CT:
@ -182,17 +182,17 @@ class Colour:
@classmethod @classmethod
def green(cls: Type[CT]) -> CT: def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
return cls(0x2ecc71) return cls(0x2ECC71)
@classmethod @classmethod
def dark_green(cls: Type[CT]) -> CT: def dark_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
return cls(0x1f8b4c) return cls(0x1F8B4C)
@classmethod @classmethod
def blue(cls: Type[CT]) -> CT: def blue(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" """A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
return cls(0x3498db) return cls(0x3498DB)
@classmethod @classmethod
def dark_blue(cls: Type[CT]) -> CT: def dark_blue(cls: Type[CT]) -> CT:
@ -202,42 +202,42 @@ class Colour:
@classmethod @classmethod
def purple(cls: Type[CT]) -> CT: def purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
return cls(0x9b59b6) return cls(0x9B59B6)
@classmethod @classmethod
def dark_purple(cls: Type[CT]) -> CT: def dark_purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
return cls(0x71368a) return cls(0x71368A)
@classmethod @classmethod
def magenta(cls: Type[CT]) -> CT: def magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
return cls(0xe91e63) return cls(0xE91E63)
@classmethod @classmethod
def dark_magenta(cls: Type[CT]) -> CT: def dark_magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" """A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
return cls(0xad1457) return cls(0xAD1457)
@classmethod @classmethod
def gold(cls: Type[CT]) -> CT: def gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
return cls(0xf1c40f) return cls(0xF1C40F)
@classmethod @classmethod
def dark_gold(cls: Type[CT]) -> CT: def dark_gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
return cls(0xc27c0e) return cls(0xC27C0E)
@classmethod @classmethod
def orange(cls: Type[CT]) -> CT: def orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
return cls(0xe67e22) return cls(0xE67E22)
@classmethod @classmethod
def dark_orange(cls: Type[CT]) -> CT: def dark_orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" """A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xa84300) return cls(0xA84300)
@classmethod @classmethod
def brand_red(cls: Type[CT]) -> CT: def brand_red(cls: Type[CT]) -> CT:
@ -250,45 +250,52 @@ class Colour:
@classmethod @classmethod
def red(cls: Type[CT]) -> CT: def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xe74c3c) return cls(0xE74C3C)
@classmethod
def nitro_booster(cls):
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
.. versionadded:: 2.0"""
return cls(0xF47FFF)
@classmethod @classmethod
def dark_red(cls: Type[CT]) -> CT: def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" """A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
return cls(0x992d22) return cls(0x992D22)
@classmethod @classmethod
def lighter_grey(cls: Type[CT]) -> CT: def lighter_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95a5a6) return cls(0x95A5A6)
lighter_gray = lighter_grey lighter_gray = lighter_grey
@classmethod @classmethod
def dark_grey(cls: Type[CT]) -> CT: def dark_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607d8b) return cls(0x607D8B)
dark_gray = dark_grey dark_gray = dark_grey
@classmethod @classmethod
def light_grey(cls: Type[CT]) -> CT: def light_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979c9f) return cls(0x979C9F)
light_gray = light_grey light_gray = light_grey
@classmethod @classmethod
def darker_grey(cls: Type[CT]) -> CT: def darker_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546e7a) return cls(0x546E7A)
darker_gray = darker_grey darker_gray = darker_grey
@classmethod @classmethod
def og_blurple(cls: Type[CT]) -> CT: def og_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" """A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
return cls(0x7289da) return cls(0x7289DA)
@classmethod @classmethod
def blurple(cls: Type[CT]) -> CT: def blurple(cls: Type[CT]) -> CT:
@ -298,7 +305,7 @@ class Colour:
@classmethod @classmethod
def greyple(cls: Type[CT]) -> CT: def greyple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
return cls(0x99aab5) return cls(0x99AAB5)
@classmethod @classmethod
def dark_theme(cls: Type[CT]) -> CT: def dark_theme(cls: Type[CT]) -> CT:
@ -325,5 +332,14 @@ class Colour:
""" """
return cls(0xFEE75C) return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
This is the original Dark Blurple branding.
.. versionadded:: 2.0
"""
return cls(0x4E5D94)
Color = Colour Color = Colour

View File

@ -41,14 +41,14 @@ if TYPE_CHECKING:
__all__ = ( __all__ = (
'Component', "Component",
'ActionRow', "ActionRow",
'Button', "Button",
'SelectMenu', "SelectMenu",
'SelectOption', "SelectOption",
) )
C = TypeVar('C', bound='Component') C = TypeVar("C", bound="Component")
class Component: class Component:
@ -70,14 +70,14 @@ class Component:
The type of component. The type of component.
""" """
__slots__: Tuple[str, ...] = ('type',) __slots__: Tuple[str, ...] = ("type",)
__repr_info__: ClassVar[Tuple[str, ...]] __repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType type: ComponentType
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>' return f"<{self.__class__.__name__} {attrs}>"
@classmethod @classmethod
def _raw_construct(cls: Type[C], **kwargs) -> C: def _raw_construct(cls: Type[C], **kwargs) -> C:
@ -112,18 +112,18 @@ class ActionRow(Component):
The children components that this holds, if any. The children components that this holds, if any.
""" """
__slots__: Tuple[str, ...] = ('children',) __slots__: Tuple[str, ...] = ("children",)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload): def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type']) self.type: ComponentType = try_enum(ComponentType, data["type"])
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] self.children: List[Component] = [_component_factory(d) for d in data.get("components", [])]
def to_dict(self) -> ActionRowPayload: def to_dict(self) -> ActionRowPayload:
return { return {
'type': int(self.type), "type": int(self.type),
'components': [child.to_dict() for child in self.children], "components": [child.to_dict() for child in self.children],
} # type: ignore } # type: ignore
@ -157,44 +157,44 @@ class Button(Component):
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'style', "style",
'custom_id', "custom_id",
'url', "url",
'disabled', "disabled",
'label', "label",
'emoji', "emoji",
) )
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload): def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type']) self.type: ComponentType = try_enum(ComponentType, data["type"])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
self.custom_id: Optional[str] = data.get('custom_id') self.custom_id: Optional[str] = data.get("custom_id")
self.url: Optional[str] = data.get('url') self.url: Optional[str] = data.get("url")
self.disabled: bool = data.get('disabled', False) self.disabled: bool = data.get("disabled", False)
self.label: Optional[str] = data.get('label') self.label: Optional[str] = data.get("label")
self.emoji: Optional[PartialEmoji] self.emoji: Optional[PartialEmoji]
try: try:
self.emoji = PartialEmoji.from_dict(data['emoji']) self.emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError: except KeyError:
self.emoji = None self.emoji = None
def to_dict(self) -> ButtonComponentPayload: def to_dict(self) -> ButtonComponentPayload:
payload = { payload = {
'type': 2, "type": 2,
'style': int(self.style), "style": int(self.style),
'label': self.label, "label": self.label,
'disabled': self.disabled, "disabled": self.disabled,
} }
if self.custom_id: if self.custom_id:
payload['custom_id'] = self.custom_id payload["custom_id"] = self.custom_id
if self.url: if self.url:
payload['url'] = self.url payload["url"] = self.url
if self.emoji: if self.emoji:
payload['emoji'] = self.emoji.to_dict() payload["emoji"] = self.emoji.to_dict()
return payload # type: ignore return payload # type: ignore
@ -231,37 +231,37 @@ class SelectMenu(Component):
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'custom_id', "custom_id",
'placeholder', "placeholder",
'min_values', "min_values",
'max_values', "max_values",
'options', "options",
'disabled', "disabled",
) )
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__ __repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload): def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select self.type = ComponentType.select
self.custom_id: str = data['custom_id'] self.custom_id: str = data["custom_id"]
self.placeholder: Optional[str] = data.get('placeholder') self.placeholder: Optional[str] = data.get("placeholder")
self.min_values: int = data.get('min_values', 1) self.min_values: int = data.get("min_values", 1)
self.max_values: int = data.get('max_values', 1) self.max_values: int = data.get("max_values", 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])]
self.disabled: bool = data.get('disabled', False) self.disabled: bool = data.get("disabled", False)
def to_dict(self) -> SelectMenuPayload: def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = { payload: SelectMenuPayload = {
'type': self.type.value, "type": self.type.value,
'custom_id': self.custom_id, "custom_id": self.custom_id,
'min_values': self.min_values, "min_values": self.min_values,
'max_values': self.max_values, "max_values": self.max_values,
'options': [op.to_dict() for op in self.options], "options": [op.to_dict() for op in self.options],
'disabled': self.disabled, "disabled": self.disabled,
} }
if self.placeholder: if self.placeholder:
payload['placeholder'] = self.placeholder payload["placeholder"] = self.placeholder
return payload return payload
@ -292,11 +292,11 @@ class SelectOption:
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'label', "label",
'value', "value",
'description', "description",
'emoji', "emoji",
'default', "default",
) )
def __init__( def __init__(
@ -318,60 +318,60 @@ class SelectOption:
elif isinstance(emoji, _EmojiTag): elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial() emoji = emoji._to_partial()
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
self.emoji = emoji self.emoji = emoji
self.default = default self.default = default
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} ' f"<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} "
f'emoji={self.emoji!r} default={self.default!r}>' f"emoji={self.emoji!r} default={self.default!r}>"
) )
def __str__(self) -> str: def __str__(self) -> str:
if self.emoji: if self.emoji:
base = f'{self.emoji} {self.label}' base = f"{self.emoji} {self.label}"
else: else:
base = self.label base = self.label
if self.description: if self.description:
return f'{base}\n{self.description}' return f"{base}\n{self.description}"
return base return base
@classmethod @classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption: def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try: try:
emoji = PartialEmoji.from_dict(data['emoji']) emoji = PartialEmoji.from_dict(data["emoji"])
except KeyError: except KeyError:
emoji = None emoji = None
return cls( return cls(
label=data['label'], label=data["label"],
value=data['value'], value=data["value"],
description=data.get('description'), description=data.get("description"),
emoji=emoji, emoji=emoji,
default=data.get('default', False), default=data.get("default", False),
) )
def to_dict(self) -> SelectOptionPayload: def to_dict(self) -> SelectOptionPayload:
payload: SelectOptionPayload = { payload: SelectOptionPayload = {
'label': self.label, "label": self.label,
'value': self.value, "value": self.value,
'default': self.default, "default": self.default,
} }
if self.emoji: if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore payload["emoji"] = self.emoji.to_dict() # type: ignore
if self.description: if self.description:
payload['description'] = self.description payload["description"] = self.description
return payload return payload
def _component_factory(data: ComponentPayload) -> Component: def _component_factory(data: ComponentPayload) -> Component:
component_type = data['type'] component_type = data["type"]
if component_type == 1: if component_type == 1:
return ActionRow(data) return ActionRow(data)
elif component_type == 2: elif component_type == 2:

View File

@ -32,11 +32,10 @@ if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
TypingT = TypeVar('TypingT', bound='Typing') TypingT = TypeVar("TypingT", bound="Typing")
__all__ = ("Typing",)
__all__ = (
'Typing',
)
def _typing_done_callback(fut: asyncio.Future) -> None: def _typing_done_callback(fut: asyncio.Future) -> None:
# just retrieve any exception and call it a day # just retrieve any exception and call it a day
@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
except (asyncio.CancelledError, Exception): except (asyncio.CancelledError, Exception):
pass pass
class Typing: class Typing:
def __init__(self, messageable: Messageable) -> None: def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop self.loop: asyncio.AbstractEventLoop = messageable._state.loop
@ -67,7 +67,8 @@ class Typing:
self.task.add_done_callback(_typing_done_callback) self.task.add_done_callback(_typing_done_callback)
return self return self
def __exit__(self, def __exit__(
self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],
@ -79,7 +80,8 @@ class Typing:
await channel._state.http.send_typing(channel.id) await channel._state.http.send_typing(channel.id)
return self.__enter__() return self.__enter__()
async def __aexit__(self, async def __aexit__(
self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException], exc_value: Optional[BaseException],
traceback: Optional[TracebackType], traceback: Optional[TracebackType],

View File

@ -30,9 +30,7 @@ from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Typ
from . import utils from . import utils
from .colour import Colour from .colour import Colour
__all__ = ( __all__ = ("Embed",)
'Embed',
)
class _EmptyEmbed: class _EmptyEmbed:
@ -40,7 +38,7 @@ class _EmptyEmbed:
return False return False
def __repr__(self) -> str: def __repr__(self) -> str:
return 'Embed.Empty' return "Embed.Empty"
def __len__(self) -> int: def __len__(self) -> int:
return 0 return 0
@ -57,19 +55,19 @@ class EmbedProxy:
return len(self.__dict__) return len(self.__dict__)
def __repr__(self) -> str: def __repr__(self) -> str:
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_'))) inner = ", ".join((f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")))
return f'EmbedProxy({inner})' return f"EmbedProxy({inner})"
def __getattr__(self, attr: str) -> _EmptyEmbed: def __getattr__(self, attr: str) -> _EmptyEmbed:
return EmptyEmbed return EmptyEmbed
E = TypeVar('E', bound='Embed') E = TypeVar("E", bound="Embed")
if TYPE_CHECKING: if TYPE_CHECKING:
from discord.types.embed import Embed as EmbedData, EmbedType from discord.types.embed import Embed as EmbedData, EmbedType
T = TypeVar('T') T = TypeVar("T")
MaybeEmpty = Union[T, _EmptyEmbed] MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol): class _EmbedFooterProxy(Protocol):
@ -157,19 +155,19 @@ class Embed:
""" """
__slots__ = ( __slots__ = (
'title', "title",
'url', "url",
'type', "type",
'_timestamp', "_timestamp",
'_colour', "_colour",
'_footer', "_footer",
'_image', "_image",
'_thumbnail', "_thumbnail",
'_video', "_video",
'_provider', "_provider",
'_author', "_author",
'_fields', "_fields",
'description', "description",
) )
Empty: Final = EmptyEmbed Empty: Final = EmptyEmbed
@ -180,7 +178,7 @@ class Embed:
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed, title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = 'rich', type: EmbedType = "rich",
url: MaybeEmpty[Any] = EmptyEmbed, url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed, description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None, timestamp: datetime.datetime = None,
@ -225,10 +223,10 @@ class Embed:
# fill in the basic fields # fill in the basic fields
self.title = data.get('title', EmptyEmbed) self.title = data.get("title", EmptyEmbed)
self.type = data.get('type', EmptyEmbed) self.type = data.get("type", EmptyEmbed)
self.description = data.get('description', EmptyEmbed) self.description = data.get("description", EmptyEmbed)
self.url = data.get('url', EmptyEmbed) self.url = data.get("url", EmptyEmbed)
if self.title is not EmptyEmbed: if self.title is not EmptyEmbed:
self.title = str(self.title) self.title = str(self.title)
@ -242,22 +240,22 @@ class Embed:
# try to fill in the more rich fields # try to fill in the more rich fields
try: try:
self._colour = Colour(value=data['color']) self._colour = Colour(value=data["color"])
except KeyError: except KeyError:
pass pass
try: try:
self._timestamp = utils.parse_time(data['timestamp']) self._timestamp = utils.parse_time(data["timestamp"])
except KeyError: except KeyError:
pass pass
for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'): for attr in ("thumbnail", "video", "provider", "author", "fields", "image", "footer"):
try: try:
value = data[attr] value = data[attr]
except KeyError: except KeyError:
continue continue
else: else:
setattr(self, '_' + attr, value) setattr(self, "_" + attr, value)
return self return self
@ -267,11 +265,11 @@ class Embed:
def __len__(self) -> int: def __len__(self) -> int:
total = len(self.title) + len(self.description) total = len(self.title) + len(self.description)
for field in getattr(self, '_fields', []): for field in getattr(self, "_fields", []):
total += len(field['name']) + len(field['value']) total += len(field["name"]) + len(field["value"])
try: try:
footer_text = self._footer['text'] footer_text = self._footer["text"]
except (AttributeError, KeyError): except (AttributeError, KeyError):
pass pass
else: else:
@ -282,7 +280,7 @@ class Embed:
except AttributeError: except AttributeError:
pass pass
else: else:
total += len(author['name']) total += len(author["name"])
return total return total
@ -306,7 +304,7 @@ class Embed:
@property @property
def colour(self) -> MaybeEmpty[Colour]: def colour(self) -> MaybeEmpty[Colour]:
return getattr(self, '_colour', EmptyEmbed) return getattr(self, "_colour", EmptyEmbed)
@colour.setter @colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore
@ -315,13 +313,15 @@ class Embed:
elif isinstance(value, int): elif isinstance(value, int):
self._colour = Colour(value=value) self._colour = Colour(value=value)
else: else:
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.') raise TypeError(
f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead."
)
color = colour color = colour
@property @property
def timestamp(self) -> MaybeEmpty[datetime.datetime]: def timestamp(self) -> MaybeEmpty[datetime.datetime]:
return getattr(self, '_timestamp', EmptyEmbed) return getattr(self, "_timestamp", EmptyEmbed)
@timestamp.setter @timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]): def timestamp(self, value: MaybeEmpty[datetime.datetime]):
@ -342,7 +342,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore return EmbedProxy(getattr(self, "_footer", {})) # type: ignore
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
"""Sets the footer for the embed content. """Sets the footer for the embed content.
@ -360,10 +360,10 @@ class Embed:
self._footer = {} self._footer = {}
if text is not EmptyEmbed: if text is not EmptyEmbed:
self._footer['text'] = str(text) self._footer["text"] = str(text)
if icon_url is not EmptyEmbed: if icon_url is not EmptyEmbed:
self._footer['icon_url'] = str(icon_url) self._footer["icon_url"] = str(icon_url)
return self return self
@ -395,7 +395,21 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_image', {})) # type: ignore return EmbedProxy(getattr(self, "_image", {})) # type: ignore
@image.setter
def image(self, url: Any):
if url is EmptyEmbed:
del self.image
else:
self._image = {"url": str(url)}
@image.deleter
def image(self):
try:
del self._image
except AttributeError:
pass
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E: def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the image for the embed content. """Sets the image for the embed content.
@ -412,16 +426,7 @@ class Embed:
The source URL for the image. Only HTTP(S) is supported. The source URL for the image. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: self.image = url
try:
del self._image
except AttributeError:
pass
else:
self._image = {
'url': str(url),
}
return self return self
@property @property
@ -437,9 +442,23 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E: @thumbnail.setter
def thumbnail(self, url: Any):
if url is EmptyEmbed:
del self.thumbnail
else:
self._thumbnail = {"url": str(url)}
@thumbnail.deleter
def thumbnail(self):
try:
del self._thumbnail
except AttributeError:
pass
def set_thumbnail(self, *, url: MaybeEmpty[Any]):
"""Sets the thumbnail for the embed content. """Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@ -454,16 +473,7 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported. The source URL for the thumbnail. Only HTTP(S) is supported.
""" """
if url is EmptyEmbed: self.thumbnail = url
try:
del self._thumbnail
except AttributeError:
pass
else:
self._thumbnail = {
'url': str(url),
}
return self return self
@property @property
@ -478,7 +488,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_video', {})) # type: ignore return EmbedProxy(getattr(self, "_video", {})) # type: ignore
@property @property
def provider(self) -> _EmbedProviderProxy: def provider(self) -> _EmbedProviderProxy:
@ -488,7 +498,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore return EmbedProxy(getattr(self, "_provider", {})) # type: ignore
@property @property
def author(self) -> _EmbedAuthorProxy: def author(self) -> _EmbedAuthorProxy:
@ -498,9 +508,11 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return EmbedProxy(getattr(self, '_author', {})) # type: ignore return EmbedProxy(getattr(self, "_author", {})) # type: ignore
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: def set_author(
self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed
) -> E:
"""Sets the author for the embed content. """Sets the author for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@ -517,14 +529,14 @@ class Embed:
""" """
self._author = { self._author = {
'name': str(name), "name": str(name),
} }
if url is not EmptyEmbed: if url is not EmptyEmbed:
self._author['url'] = str(url) self._author["url"] = str(url)
if icon_url is not EmptyEmbed: if icon_url is not EmptyEmbed:
self._author['icon_url'] = str(icon_url) self._author["icon_url"] = str(icon_url)
return self return self
@ -551,7 +563,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned. If the attribute has no value then :attr:`Empty` is returned.
""" """
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E: def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E:
"""Adds a field to the embed object. """Adds a field to the embed object.
@ -570,9 +582,9 @@ class Embed:
""" """
field = { field = {
'inline': inline, "inline": inline,
'name': str(name), "name": str(name),
'value': str(value), "value": str(value),
} }
try: try:
@ -603,9 +615,9 @@ class Embed:
""" """
field = { field = {
'inline': inline, "inline": inline,
'name': str(name), "name": str(name),
'value': str(value), "value": str(value),
} }
try: try:
@ -671,11 +683,11 @@ class Embed:
try: try:
field = self._fields[index] field = self._fields[index]
except (TypeError, IndexError, AttributeError): except (TypeError, IndexError, AttributeError):
raise IndexError('field index out of range') raise IndexError("field index out of range")
field['name'] = str(name) field["name"] = str(name)
field['value'] = str(value) field["value"] = str(value)
field['inline'] = inline field["inline"] = inline
return self return self
def to_dict(self) -> EmbedData: def to_dict(self) -> EmbedData:
@ -693,35 +705,35 @@ class Embed:
# deal with basic convenience wrappers # deal with basic convenience wrappers
try: try:
colour = result.pop('colour') colour = result.pop("colour")
except KeyError: except KeyError:
pass pass
else: else:
if colour: if colour:
result['color'] = colour.value result["color"] = colour.value
try: try:
timestamp = result.pop('timestamp') timestamp = result.pop("timestamp")
except KeyError: except KeyError:
pass pass
else: else:
if timestamp: if timestamp:
if timestamp.tzinfo: if timestamp.tzinfo:
result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
else: else:
result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
# add in the non raw attribute ones # add in the non raw attribute ones
if self.type: if self.type:
result['type'] = self.type result["type"] = self.type
if self.description: if self.description:
result['description'] = self.description result["description"] = self.description
if self.url: if self.url:
result['url'] = self.url result["url"] = self.url
if self.title: if self.title:
result['title'] = self.title result["title"] = self.title
return result # type: ignore return result # type: ignore

View File

@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User from .user import User
__all__ = ( __all__ = ("Emoji",)
'Emoji',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload from .types.emoji import Emoji as EmojiPayload
@ -72,6 +70,10 @@ class Emoji(_EmojiTag, AssetMixin):
Returns the emoji rendered for discord. Returns the emoji rendered for discord.
.. describe:: int(x)
Returns the emoji ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -94,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin):
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'require_colons', "require_colons",
'animated', "animated",
'managed', "managed",
'id', "id",
'name', "name",
'_roles', "_roles",
'guild_id', "guild_id",
'_state', "_state",
'user', "user",
'available', "available",
) )
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
@ -112,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin):
self._from_data(data) self._from_data(data)
def _from_data(self, emoji: EmojiPayload): def _from_data(self, emoji: EmojiPayload):
self.require_colons: bool = emoji.get('require_colons', False) self.require_colons: bool = emoji.get("require_colons", False)
self.managed: bool = emoji.get('managed', False) self.managed: bool = emoji.get("managed", False)
self.id: int = int(emoji['id']) # type: ignore self.id: int = int(emoji["id"]) # type: ignore
self.name: str = emoji['name'] # type: ignore self.name: str = emoji["name"] # type: ignore
self.animated: bool = emoji.get('animated', False) self.animated: bool = emoji.get("animated", False)
self.available: bool = emoji.get('available', True) self.available: bool = emoji.get("available", True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", [])))
user = emoji.get('user') user = emoji.get("user")
self.user: Optional[User] = User(state=self._state, data=user) if user else None self.user: Optional[User] = User(state=self._state, data=user) if user else None
def _to_partial(self) -> PartialEmoji: def _to_partial(self) -> PartialEmoji:
@ -127,18 +129,21 @@ class Emoji(_EmojiTag, AssetMixin):
def __iter__(self) -> Iterator[Tuple[str, Any]]: def __iter__(self) -> Iterator[Tuple[str, Any]]:
for attr in self.__slots__: for attr in self.__slots__:
if attr[0] != '_': if attr[0] != "_":
value = getattr(self, attr, None) value = getattr(self, attr, None)
if value is not None: if value is not None:
yield (attr, value) yield (attr, value)
def __str__(self) -> str: def __str__(self) -> str:
if self.animated: if self.animated:
return f'<a:{self.name}:{self.id}>' return f"<a:{self.name}:{self.id}>"
return f'<:{self.name}:{self.id}>' return f"<:{self.name}:{self.id}>"
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>' return f"<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>"
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id return isinstance(other, _EmojiTag) and self.id == other.id
@ -157,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin):
@property @property
def url(self) -> str: def url(self) -> str:
""":class:`str`: Returns the URL of the emoji.""" """:class:`str`: Returns the URL of the emoji."""
fmt = 'gif' if self.animated else 'png' fmt = "gif" if self.animated else "png"
return f'{Asset.BASE}/emojis/{self.id}.{fmt}' return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
@property @property
def roles(self) -> List[Role]: def roles(self) -> List[Role]:
@ -212,7 +217,9 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji: async def edit(
self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None
) -> Emoji:
r"""|coro| r"""|coro|
Edits the custom emoji. Edits the custom emoji.
@ -247,9 +254,9 @@ class Emoji(_EmojiTag, AssetMixin):
payload = {} payload = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = name payload["name"] = name
if roles is not MISSING: if roles is not MISSING:
payload['roles'] = [role.id for role in roles] payload["roles"] = [role.id for role in roles]
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason) data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state) return Emoji(guild=self.guild, data=data, state=self._state)

View File

@ -27,41 +27,42 @@ from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
__all__ = ( __all__ = (
'Enum', "Enum",
'ChannelType', "ChannelType",
'MessageType', "MessageType",
'VoiceRegion', "VoiceRegion",
'SpeakingState', "SpeakingState",
'VerificationLevel', "VerificationLevel",
'ContentFilter', "ContentFilter",
'Status', "Status",
'DefaultAvatar', "DefaultAvatar",
'AuditLogAction', "AuditLogAction",
'AuditLogActionCategory', "AuditLogActionCategory",
'UserFlags', "UserFlags",
'ActivityType', "ActivityType",
'NotificationLevel', "NotificationLevel",
'TeamMembershipState', "TeamMembershipState",
'WebhookType', "WebhookType",
'ExpireBehaviour', "ExpireBehaviour",
'ExpireBehavior', "ExpireBehavior",
'StickerType', "StickerType",
'StickerFormatType', "StickerFormatType",
'InviteTarget', "InviteTarget",
'VideoQualityMode', "VideoQualityMode",
'ComponentType', "ComponentType",
'ButtonStyle', "ButtonStyle",
'StagePrivacyLevel', "StagePrivacyLevel",
'InteractionType', "InteractionType",
'InteractionResponseType', "InteractionResponseType",
'NSFWLevel', "NSFWLevel",
"ProtocolURL",
) )
def _create_value_cls(name, comparable): def _create_value_cls(name, comparable):
cls = namedtuple('_EnumValue_' + name, 'name value') cls = namedtuple("_EnumValue_" + name, "name value")
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f'{name}.{self.name}' cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable: if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
@ -69,8 +70,9 @@ def _create_value_cls(name, comparable):
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls return cls
def _is_descriptor(obj): def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
class EnumMeta(type): class EnumMeta(type):
@ -88,7 +90,7 @@ class EnumMeta(type):
value_cls = _create_value_cls(name, comparable) value_cls = _create_value_cls(name, comparable)
for key, value in list(attrs.items()): for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value) is_descriptor = _is_descriptor(value)
if key[0] == '_' and not is_descriptor: if key[0] == "_" and not is_descriptor:
continue continue
# Special case classmethod to just pass through # Special case classmethod to just pass through
@ -110,10 +112,10 @@ class EnumMeta(type):
member_mapping[key] = new_value member_mapping[key] = new_value
attrs[key] = new_value attrs[key] = new_value
attrs['_enum_value_map_'] = value_mapping attrs["_enum_value_map_"] = value_mapping
attrs['_enum_member_map_'] = member_mapping attrs["_enum_member_map_"] = member_mapping
attrs['_enum_member_names_'] = member_names attrs["_enum_member_names_"] = member_names
attrs['_enum_value_cls_'] = value_cls attrs["_enum_value_cls_"] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs) actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore value_cls._actual_enum_cls_ = actual_cls # type: ignore
return actual_cls return actual_cls
@ -128,7 +130,7 @@ class EnumMeta(type):
return len(cls._enum_member_names_) return len(cls._enum_member_names_)
def __repr__(cls): def __repr__(cls):
return f'<enum {cls.__name__}>' return f"<enum {cls.__name__}>"
@property @property
def __members__(cls): def __members__(cls):
@ -144,10 +146,10 @@ class EnumMeta(type):
return cls._enum_member_map_[key] return cls._enum_member_map_[key]
def __setattr__(cls, name, value): def __setattr__(cls, name, value):
raise TypeError('Enums are immutable.') raise TypeError("Enums are immutable.")
def __delattr__(cls, attr): def __delattr__(cls, attr):
raise TypeError('Enums are immutable') raise TypeError("Enums are immutable")
def __instancecheck__(self, instance): def __instancecheck__(self, instance):
# isinstance(x, Y) # isinstance(x, Y)
@ -215,29 +217,29 @@ class MessageType(Enum):
class VoiceRegion(Enum): class VoiceRegion(Enum):
us_west = 'us-west' us_west = "us-west"
us_east = 'us-east' us_east = "us-east"
us_south = 'us-south' us_south = "us-south"
us_central = 'us-central' us_central = "us-central"
eu_west = 'eu-west' eu_west = "eu-west"
eu_central = 'eu-central' eu_central = "eu-central"
singapore = 'singapore' singapore = "singapore"
london = 'london' london = "london"
sydney = 'sydney' sydney = "sydney"
amsterdam = 'amsterdam' amsterdam = "amsterdam"
frankfurt = 'frankfurt' frankfurt = "frankfurt"
brazil = 'brazil' brazil = "brazil"
hongkong = 'hongkong' hongkong = "hongkong"
russia = 'russia' russia = "russia"
japan = 'japan' japan = "japan"
southafrica = 'southafrica' southafrica = "southafrica"
south_korea = 'south-korea' south_korea = "south-korea"
india = 'india' india = "india"
europe = 'europe' europe = "europe"
dubai = 'dubai' dubai = "dubai"
vip_us_east = 'vip-us-east' vip_us_east = "vip-us-east"
vip_us_west = 'vip-us-west' vip_us_west = "vip-us-west"
vip_amsterdam = 'vip-amsterdam' vip_amsterdam = "vip-amsterdam"
def __str__(self): def __str__(self):
return self.value return self.value
@ -277,12 +279,12 @@ class ContentFilter(Enum, comparable=True):
class Status(Enum): class Status(Enum):
online = 'online' online = "online"
offline = 'offline' offline = "offline"
idle = 'idle' idle = "idle"
dnd = 'dnd' dnd = "dnd"
do_not_disturb = 'dnd' do_not_disturb = "dnd"
invisible = 'invisible' invisible = "invisible"
def __str__(self): def __str__(self):
return self.value return self.value
@ -415,33 +417,33 @@ class AuditLogAction(Enum):
def target_type(self) -> Optional[str]: def target_type(self) -> Optional[str]:
v = self.value v = self.value
if v == -1: if v == -1:
return 'all' return "all"
elif v < 10: elif v < 10:
return 'guild' return "guild"
elif v < 20: elif v < 20:
return 'channel' return "channel"
elif v < 30: elif v < 30:
return 'user' return "user"
elif v < 40: elif v < 40:
return 'role' return "role"
elif v < 50: elif v < 50:
return 'invite' return "invite"
elif v < 60: elif v < 60:
return 'webhook' return "webhook"
elif v < 70: elif v < 70:
return 'emoji' return "emoji"
elif v == 73: elif v == 73:
return 'channel' return "channel"
elif v < 80: elif v < 80:
return 'message' return "message"
elif v < 83: elif v < 83:
return 'integration' return "integration"
elif v < 90: elif v < 90:
return 'stage_instance' return "stage_instance"
elif v < 93: elif v < 93:
return 'sticker' return "sticker"
elif v < 113: elif v < 113:
return 'thread' return "thread"
class UserFlags(Enum): class UserFlags(Enum):
@ -528,6 +530,7 @@ class InteractionType(Enum):
ping = 1 ping = 1
application_command = 2 application_command = 2
component = 3 component = 3
application_command_autocomplete = 4
class InteractionResponseType(Enum): class InteractionResponseType(Enum):
@ -538,6 +541,7 @@ class InteractionResponseType(Enum):
deferred_channel_message = 5 # (with source) deferred_channel_message = 5 # (with source)
deferred_message_update = 6 # for components deferred_message_update = 6 # for components
message_update = 7 # for components message_update = 7 # for components
application_command_autocomplete_result = 8
class VideoQualityMode(Enum): class VideoQualityMode(Enum):
@ -589,12 +593,80 @@ class NSFWLevel(Enum, comparable=True):
age_restricted = 3 age_restricted = 3
T = TypeVar('T') class ProtocolURL(Enum):
# General
home = "discord://-/channels/@me/"
nitro = "discord://-/store"
apps = "discord://-/apps" # Breaks the client on windows (Shows download links for different OS)
guild_discovery = "discord://-/guild-discovery"
guild_create = "discord://-/guilds/create"
guild_invite = "discord://-/invite/{invite_code}"
# Settings
account_settings = "discord://-/settings/account"
profile_settings = "discord://-/settings/profile-customization"
privacy_settings = "discord://-/settings/privacy-and-safety"
safety_settings = "discord://-/settings/privacy-and-safety" # Alias
authorized_apps_settings = "discord://-/settings/authorized-apps"
connections_settings = "discord://-/settings/connections"
nitro_settings = "discord://-/settings/premium" # Same as store, but inside of settings
guild_premium_subscription = "discord://-/settings/premium-guild-subscription"
subscription_settings = "discord://-/settings/subscriptions"
gift_inventory_settings = "discord://-/settings/inventory"
billing_settings = "discord://-/settings/billing"
appearance_settings = "discord://-/settings/appearance"
accessibility_settings = "discord://-/settings/accessibility"
voice_video_settings = "discord://-/settings/voice"
text_images_settings = "discord://-/settings/text"
notifications_settings = "discord://-/settings/notifications"
keybinds_settings = "discord://-/settings/keybinds"
language_settings = "discord://-/settings/locale"
windows_settings = "discord://-/settings/windows" # Doesnt work if used on wrong platform
linux_settings = "discord://-/settings/linux" # Doesnt work if used on wrong platform
streamer_mode_settings = "discord://-/settings/streamer-mode"
advanced_settings = "discord://-/settings/advanced"
activity_status_settings = "discord://-/settings/activity-status"
game_overlay_settings = "discord://-/settings/overlay"
hypesquad_settings = "discord://-/settings/hypesquad-online"
changelogs = "discord://-/settings/changelogs"
# Doesn't work if you don't have it actually activated. Just blank screen.
experiments = "discord://-/settings/experiments"
developer_options = "discord://-/settings/developer-options" # Same as experiments
hotspot_options = "discord://-/settings/hotspot-options" # Same as experiments
# Users, Guilds, and DMs
user_profile = "discord://-/users/{user_id}"
dm_channel = "discord://-/channels/@me/{channel_id}"
dm_message = "discord://-/channels/@me/{channel_id}/{message_id}"
guild_channel = "discord://-/channels/{guild_id}/{channel_id}"
guild_message = "discord://-/channels/{guild_id}/{channel_id}/{message_id}"
guild_membership_screening = "discord://-/member-verification/{guild_id}"
# Library
games_library = "discord://-/library"
library_settings = "discord://-/library/settings"
def __str__(self) -> str:
return self.value
def format(self, **kwargs: Any) -> str:
return self.value.format(**kwargs)
T = TypeVar("T")
def create_unknown_value(cls: Type[T], val: Any) -> T: def create_unknown_value(cls: Type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore value_cls = cls._enum_value_cls_ # type: ignore
name = f'unknown_{val}' name = f"unknown_{val}"
return value_cls(name=name, value=val) return value_cls(name=name, value=val)

View File

@ -38,20 +38,20 @@ if TYPE_CHECKING:
from .interactions import Interaction from .interactions import Interaction
__all__ = ( __all__ = (
'DiscordException', "DiscordException",
'ClientException', "ClientException",
'NoMoreItems', "NoMoreItems",
'GatewayNotFound', "GatewayNotFound",
'HTTPException', "HTTPException",
'Forbidden', "Forbidden",
'NotFound', "NotFound",
'DiscordServerError', "DiscordServerError",
'InvalidData', "InvalidData",
'InvalidArgument', "InvalidArgument",
'LoginFailure', "LoginFailure",
'ConnectionClosed', "ConnectionClosed",
'PrivilegedIntentsRequired', "PrivilegedIntentsRequired",
'InteractionResponded', "InteractionResponded",
) )
@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException):
"""An exception that is raised when the gateway for Discord could not be found""" """An exception that is raised when the gateway for Discord could not be found"""
def __init__(self): def __init__(self):
message = 'The gateway to connect to discord was not found.' message = "The gateway to connect to discord was not found."
super().__init__(message) super().__init__(message)
def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]: def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]:
items: List[Tuple[str, str]] = [] items: List[Tuple[str, str]] = []
for k, v in d.items(): for k, v in d.items():
new_key = key + '.' + k if key else k new_key = key + "." + k if key else k
if isinstance(v, dict): if isinstance(v, dict):
try: try:
_errors: List[Dict[str, Any]] = v['_errors'] _errors: List[Dict[str, Any]] = v["_errors"]
except KeyError: except KeyError:
items.extend(_flatten_error_dict(v, new_key).items()) items.extend(_flatten_error_dict(v, new_key).items())
else: else:
items.append((new_key, ' '.join(x.get('message', '') for x in _errors))) items.append((new_key, " ".join(x.get("message", "") for x in _errors)))
else: else:
items.append((new_key, v)) items.append((new_key, v))
@ -129,22 +129,22 @@ class HTTPException(DiscordException):
self.code: int self.code: int
self.text: str self.text: str
if isinstance(message, dict): if isinstance(message, dict):
self.code = message.get('code', 0) self.code = message.get("code", 0)
base = message.get('message', '') base = message.get("message", "")
errors = message.get('errors') errors = message.get("errors")
if errors: if errors:
errors = _flatten_error_dict(errors) errors = _flatten_error_dict(errors)
helpful = '\n'.join('In %s: %s' % t for t in errors.items()) helpful = "\n".join("In %s: %s" % t for t in errors.items())
self.text = base + '\n' + helpful self.text = base + "\n" + helpful
else: else:
self.text = base self.text = base
else: else:
self.text = message or '' self.text = message or ""
self.code = 0 self.code = 0
fmt = '{0.status} {0.reason} (error code: {1})' fmt = "{0.status} {0.reason} (error code: {1})"
if len(self.text): if len(self.text):
fmt += ': {2}' fmt += ": {2}"
super().__init__(fmt.format(self.response, self.code, self.text)) super().__init__(fmt.format(self.response, self.code, self.text))
@ -226,9 +226,9 @@ class ConnectionClosed(ClientException):
# reconfigured to subclass ClientException for users # reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1 self.code: int = code or socket.close_code or -1
# aiohttp doesn't seem to consistently provide close reason # aiohttp doesn't seem to consistently provide close reason
self.reason: str = '' self.reason: str = ""
self.shard_id: Optional[int] = shard_id self.shard_id: Optional[int] = shard_id
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}') super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}")
class PrivilegedIntentsRequired(ClientException): class PrivilegedIntentsRequired(ClientException):
@ -250,10 +250,10 @@ class PrivilegedIntentsRequired(ClientException):
def __init__(self, shard_id: Optional[int]): def __init__(self, shard_id: Optional[int]):
self.shard_id: Optional[int] = shard_id self.shard_id: Optional[int] = shard_id
msg = ( msg = (
'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' "Shard ID %s is requesting privileged intents that have not been explicitly enabled in the "
'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' "developer portal. It is recommended to go to https://discord.com/developers/applications/ "
'and explicitly enable the privileged intents within your application\'s page. If this is not ' "and explicitly enable the privileged intents within your application's page. If this is not "
'possible, then consider disabling the privileged intents instead.' "possible, then consider disabling the privileged intents instead."
) )
super().__init__(msg % shard_id) super().__init__(msg % shard_id)
@ -274,4 +274,4 @@ class InteractionResponded(ClientException):
def __init__(self, interaction: Interaction): def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction self.interaction: Interaction = interaction
super().__init__('This interaction has already been responded to before') super().__init__("This interaction has already been responded to before")

View File

@ -31,7 +31,7 @@ if TYPE_CHECKING:
from .cog import Cog from .cog import Cog
from .errors import CommandError from .errors import CommandError
T = TypeVar('T') T = TypeVar("T")
Coro = Coroutine[Any, Any, T] Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]] MaybeCoro = Union[T, Coro[T]]
@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
]
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.

View File

@ -28,18 +28,44 @@ from __future__ import annotations
import asyncio import asyncio
import collections import collections
import collections.abc import collections.abc
from functools import cached_property
import inspect import inspect
import importlib.util import importlib.util
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union from collections import defaultdict
from discord.http import HTTPClient
from typing import (
Any,
Callable,
Iterable,
Tuple,
cast,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord import discord
from discord.types.interactions import (
ApplicationCommandInteractionData,
ApplicationCommandInteractionDataOption,
EditApplicationCommand,
_ApplicationCommandInteractionDataOptionString,
)
from .core import GroupMixin from .core import GroupMixin
from .view import StringView from .converter import Greedy
from .view import StringView, supported_quotes
from .context import Context from .context import Context
from .flags import FlagConverter
from . import errors from . import errors
from .help import HelpCommand, DefaultHelpCommand from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
@ -47,24 +73,67 @@ from .cog import Cog
if TYPE_CHECKING: if TYPE_CHECKING:
import importlib.machinery import importlib.machinery
from discord.role import Role
from discord.message import Message from discord.message import Message
from discord.abc import PartialMessageableChannel
from ._types import ( from ._types import (
Check, Check,
CoroFunc, CoroFunc,
) )
__all__ = ( __all__ = (
'when_mentioned', "when_mentioned",
'when_mentioned_or', "when_mentioned_or",
'Bot', "Bot",
'AutoShardedBot', "AutoShardedBot",
) )
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar("T")
CFT = TypeVar('CFT', bound='CoroFunc') CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar('CXT', bound='Context') CXT = TypeVar("CXT", bound="Context")
class _FakeSlashMessage(discord.PartialMessage):
activity = application = edited_at = reference = webhook_id = None
attachments = components = reactions = stickers = []
tts = False
raw_mentions = discord.Message.raw_mentions
clean_content = discord.Message.clean_content
channel_mentions = discord.Message.channel_mentions
raw_role_mentions = discord.Message.raw_role_mentions
raw_channel_mentions = discord.Message.raw_channel_mentions
author: Union[discord.User, discord.Member]
@classmethod
def from_interaction(
cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread]
):
self = cls(channel=channel, id=interaction.id)
assert interaction.user is not None
self.author = interaction.user
return self
@cached_property
def mentions(self) -> List[Union[discord.Member, discord.User]]:
client = self._state._get_client()
if self.guild:
ensure_user = lambda id: self.guild.get_member(id) or client.get_user(id) # type: ignore
else:
ensure_user = client.get_user
return discord.utils._unique(filter(None, map(ensure_user, self.raw_mentions)))
@cached_property
def role_mentions(self) -> List[Role]:
if self.guild is None:
return []
return discord.utils._unique(filter(None, map(self.guild.get_role, self.raw_role_mentions)))
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
@ -72,7 +141,8 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
""" """
# bot.user will never be None when this is called # bot.user will never be None when this is called
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
@ -103,6 +173,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
---------- ----------
:func:`.when_mentioned` :func:`.when_mentioned`
""" """
def inner(bot, msg): def inner(bot, msg):
r = list(prefixes) r = list(prefixes)
r = when_mentioned(bot, msg) + r r = when_mentioned(bot, msg) + r
@ -110,19 +181,66 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner return inner
def _is_submodule(parent: str, child: str) -> bool: def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".") return parent == child or child.startswith(parent + ".")
def _unwrap_slash_groups(
data: ApplicationCommandInteractionData,
) -> Tuple[str, Dict[str, ApplicationCommandInteractionDataOption]]:
command_name = data["name"]
command_options: Any = data.get("options") or []
while True:
try:
option = next(o for o in command_options if o["type"] in {1, 2})
except StopIteration:
return command_name, {o["name"]: o for o in command_options}
else:
command_name += f' {option["name"]}'
command_options = option.get("options") or []
def _quote_string_safe(string: str) -> str:
# we need to quote this string otherwise we may spill into
# other parameters and cause all kinds of trouble, as many
# quotes are supported and some may be in the option, we
# loop through all supported quotes and if neither open or
# close are in the string, we add them
for open, close in supported_quotes.items():
if open not in string and close not in string:
return f"{open}{string}{close}"
# all supported quotes are in the message and we cannot add any
# safely, very unlikely but still got to be covered
raise errors.UnexpectedQuoteError(string)
class _DefaultRepr: class _DefaultRepr:
def __repr__(self): def __repr__(self):
return '<default-help-command>' return "<default-help-command>"
_default = _DefaultRepr() _default = _DefaultRepr()
class BotBase(GroupMixin): class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(
super().__init__(**options) self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
message_commands: bool = True,
slash_commands: bool = False,
**options,
):
super().__init__(**options, intents=intents)
self.command_prefix = command_prefix self.command_prefix = command_prefix
self.slash_commands = slash_commands
self.message_commands = message_commands
self.extra_events: Dict[str, List[CoroFunc]] = {} self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {} self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {} self.__extensions: Dict[str, types.ModuleType] = {}
@ -131,16 +249,20 @@ class BotBase(GroupMixin):
self._before_invoke = None self._before_invoke = None
self._after_invoke = None self._after_invoke = None
self._help_command = None self._help_command = None
self.description = inspect.cleandoc(description) if description else '' self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get('owner_id') self.owner_id = options.get("owner_id")
self.owner_ids = options.get('owner_ids', set()) self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get('strip_after_prefix', False) self.strip_after_prefix = options.get("strip_after_prefix", False)
self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None)
if self.owner_id and self.owner_ids: if self.owner_id and self.owner_ids:
raise TypeError('Both owner_id and owner_ids are set.') raise TypeError("Both owner_id and owner_ids are set.")
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}') raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
if not (message_commands or slash_commands):
raise ValueError("Both message_commands and slash_commands are disabled.")
if help_command is _default: if help_command is _default:
self.help_command = DefaultHelpCommand() self.help_command = DefaultHelpCommand()
@ -152,10 +274,59 @@ class BotBase(GroupMixin):
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client # super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name ev = "on_" + event_name
for event in self.extra_events.get(ev, []): for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore self._schedule_event(event, ev, *args, **kwargs) # type: ignore
async def setup(self):
await self.create_slash_commands()
async def create_slash_commands(self):
commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list)
for command in self.commands:
if command.hidden or (command.slash_command is None and not self.slash_commands):
continue
try:
payload = command.to_application_command()
except Exception:
raise errors.ApplicationCommandRegistrationError(command)
if payload is None:
continue
guilds = command.slash_command_guilds or self.slash_command_guilds
if guilds is None:
commands[None].append(payload)
else:
for guild in guilds:
commands[guild].append(payload)
http: HTTPClient = self.http # type: ignore
global_commands = commands.pop(None, None)
application_id = self.application_id or (await self.application_info()).id # type: ignore
if global_commands is not None:
if self.slash_command_guilds is None:
await http.bulk_upsert_global_commands(
payload=global_commands,
application_id=application_id,
)
else:
for guild in self.slash_command_guilds:
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=global_commands,
application_id=application_id,
)
for guild, guild_commands in commands.items():
assert guild is not None
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=guild_commands,
application_id=application_id,
)
@discord.utils.copy_doc(discord.Client.close) @discord.utils.copy_doc(discord.Client.close)
async def close(self) -> None: async def close(self) -> None:
for extension in tuple(self.__extensions): for extension in tuple(self.__extensions):
@ -182,7 +353,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error. This only fires if you do not specify any listeners for command error.
""" """
if self.extra_events.get('on_command_error', None): if self.extra_events.get("on_command_error", None):
return return
command = context.command command = context.command
@ -193,7 +364,7 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler(): if cog and cog.has_error_handler():
return return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr) print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
# global check registration # global check registration
@ -344,14 +515,59 @@ class BotBase(GroupMixin):
elif self.owner_ids: elif self.owner_ids:
return user.id in self.owner_ids return user.id in self.owner_ids
else: else:
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
await self.populate_owners()
return await self.is_owner(user)
app = await self.application_info() # type: ignore async def try_owners(self) -> List[discord.User]:
if app.team: """|coro|
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids Returns a list of :class:`~discord.User` representing the owners of the bot.
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
.. versionadded:: 2.0
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
Returns
--------
List[:class:`~discord.User`]
List of owners of the bot.
"""
if self.owner_id:
owner = await self.try_user(self.owner_id)
if owner:
return [owner]
else: else:
self.owner_id = owner_id = app.owner.id return []
return user.id == owner_id
elif self.owner_ids:
owners = []
for owner_id in self.owner_ids:
owner = await self.try_user(owner_id)
if owner:
owners.append(owner)
return owners
else:
# We didn't have owners cached yet, cache them and retry.
await self.populate_owners()
return await self.try_owners()
async def populate_owners(self):
"""|coro|
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
.. versionadded:: 2.0
"""
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = {m.id for m in app.team.members}
else:
self.owner_id = app.owner.id
def before_invoke(self, coro: CFT) -> CFT: def before_invoke(self, coro: CFT) -> CFT:
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
@ -380,7 +596,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.') raise TypeError("The pre-invoke hook must be a coroutine.")
self._before_invoke = coro self._before_invoke = coro
return coro return coro
@ -413,7 +629,7 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.') raise TypeError("The post-invoke hook must be a coroutine.")
self._after_invoke = coro self._after_invoke = coro
return coro return coro
@ -445,7 +661,7 @@ class BotBase(GroupMixin):
name = func.__name__ if name is MISSING else name name = func.__name__ if name is MISSING else name
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines') raise TypeError("Listeners must be coroutines")
if name in self.extra_events: if name in self.extra_events:
self.extra_events[name].append(func) self.extra_events[name].append(func)
@ -541,14 +757,14 @@ class BotBase(GroupMixin):
""" """
if not isinstance(cog, Cog): if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog') raise TypeError("cogs must derive from Cog")
cog_name = cog.__cog_name__ cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name) existing = self.__cogs.get(cog_name)
if existing is not None: if existing is not None:
if not override: if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded') raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
self.remove_cog(cog_name) self.remove_cog(cog_name)
cog = cog._inject(self) cog = cog._inject(self)
@ -636,7 +852,7 @@ class BotBase(GroupMixin):
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
try: try:
func = getattr(lib, 'teardown') func = getattr(lib, "teardown")
except AttributeError: except AttributeError:
pass pass
else: else:
@ -663,7 +879,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
try: try:
setup = getattr(lib, 'setup') setup = getattr(lib, "setup")
except AttributeError: except AttributeError:
del sys.modules[key] del sys.modules[key]
raise errors.NoEntryPointError(key) raise errors.NoEntryPointError(key)
@ -813,11 +1029,7 @@ class BotBase(GroupMixin):
raise errors.ExtensionNotLoaded(name) raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules # get the previous module states from sys modules
modules = { modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)}
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
try: try:
# Unload and then load the module... # Unload and then load the module...
@ -850,7 +1062,7 @@ class BotBase(GroupMixin):
def help_command(self, value: Optional[HelpCommand]) -> None: def help_command(self, value: Optional[HelpCommand]) -> None:
if value is not None: if value is not None:
if not isinstance(value, HelpCommand): if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand') raise TypeError("help_command must be a subclass of HelpCommand")
if self._help_command is not None: if self._help_command is not None:
self._help_command._remove_from_bot(self) self._help_command._remove_from_bot(self)
self._help_command = value self._help_command = value
@ -880,6 +1092,9 @@ class BotBase(GroupMixin):
A list of prefixes or a single prefix that the bot is A list of prefixes or a single prefix that the bot is
listening for. listening for.
""" """
if isinstance(message, _FakeSlashMessage):
return "/"
prefix = ret = self.command_prefix prefix = ret = self.command_prefix
if callable(prefix): if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message) ret = await discord.utils.maybe_coroutine(prefix, self, message)
@ -893,8 +1108,10 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable): if isinstance(ret, collections.abc.Iterable):
raise raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable " raise TypeError(
f"returning either of these, not {ret.__class__.__name__}") "command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
if not ret: if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix") raise ValueError("Iterable command_prefix must contain at least one prefix")
@ -954,14 +1171,18 @@ class BotBase(GroupMixin):
except TypeError: except TypeError:
if not isinstance(prefix, list): if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, " raise TypeError(
f"not {prefix.__class__.__name__}") "get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
# It's possible a bad command_prefix got us here. # It's possible a bad command_prefix got us here.
for value in prefix: for value in prefix:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must " raise TypeError(
f"contain only strings, not {value.__class__.__name__}") "Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
# Getting here shouldn't happen # Getting here shouldn't happen
raise raise
@ -988,19 +1209,19 @@ class BotBase(GroupMixin):
The invocation context to invoke. The invocation context to invoke.
""" """
if ctx.command is not None: if ctx.command is not None:
self.dispatch('command', ctx) self.dispatch("command", ctx)
try: try:
if await self.can_run(ctx, call_once=True): if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx) await ctx.command.invoke(ctx)
else: else:
raise errors.CheckFailure('The global check once functions failed.') raise errors.CheckFailure("The global check once functions failed.")
except errors.CommandError as exc: except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc) await ctx.command.dispatch_error(ctx, exc)
else: else:
self.dispatch('command_completion', ctx) self.dispatch("command_completion", ctx)
elif ctx.invoked_with: elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc) self.dispatch("command_error", ctx, exc)
async def process_commands(self, message: Message) -> None: async def process_commands(self, message: Message) -> None:
"""|coro| """|coro|
@ -1030,9 +1251,95 @@ class BotBase(GroupMixin):
ctx = await self.get_context(message) ctx = await self.get_context(message)
await self.invoke(ctx) await self.invoke(ctx)
async def process_slash_commands(self, interaction: discord.Interaction):
"""|coro|
This function processes a slash command interaction into a usable
message and calls :meth:`.process_commands` based on it. Without this
coroutine slash commands will not be triggered.
By default, this coroutine is called inside the :func:`.on_interaction`
event. If you choose to override the :func:`.on_interaction` event,
then you should invoke this coroutine as well.
.. versionadded:: 2.0
Parameters
-----------
interaction: :class:`discord.Interaction`
The interaction to process slash commands for.
"""
if interaction.type != discord.InteractionType.application_command:
return
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
command_name, command_options = _unwrap_slash_groups(interaction.data)
command = self.get_command(command_name)
if command is None:
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
# Ensure the interaction channel is usable
channel = interaction.channel
if channel is None or isinstance(channel, discord.PartialMessageable):
if interaction.guild is None:
assert interaction.user is not None
channel = await interaction.user.create_dm()
elif interaction.channel_id is not None:
channel = await interaction.guild.fetch_channel(interaction.channel_id)
else:
return # cannot do anything without stable channel
# Make our fake message so we can pass it to ext.commands
message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore
message.content = f"/{command_name}"
# Add arguments to fake message content, in the right order
ignore_params: List[inspect.Parameter] = []
for name, param in command.clean_params.items():
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
for name, flag in param.annotation.get_flags().items():
option = command_options.get(name)
if option is None:
if flag.required:
raise errors.MissingRequiredFlag(flag)
else:
prefix = param.annotation.__commands_flag_prefix__
delimiter = param.annotation.__commands_flag_delimiter__
message.content += f" {prefix}{name}{delimiter}{option['value']}" # type: ignore
continue
option = command_options.get(name)
if option is None:
if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
else:
ignore_params.append(param)
elif (
option["type"] == 3
and not isinstance(param.annotation, Greedy)
and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}
):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
message.content += f" {_quote_string_safe(option['value'])}"
else:
message.content += f' {option.get("value", "")}'
ctx = await self.get_context(message)
ctx._ignored_params = ignore_params
ctx.interaction = interaction
await self.invoke(ctx)
async def on_message(self, message): async def on_message(self, message):
await self.process_commands(message) await self.process_commands(message)
async def on_interaction(self, interaction: discord.Interaction):
await self.process_slash_commands(interaction)
class Bot(BotBase, discord.Client): class Bot(BotBase, discord.Client):
"""Represents a discord bot. """Represents a discord bot.
@ -1075,7 +1382,7 @@ class Bot(BotBase, discord.Client):
when passing an empty string, it should always be last as no prefix when passing an empty string, it should always be last as no prefix
after it will be matched. after it will be matched.
case_insensitive: :class:`bool` case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This Whether the commands should be case insensitive. Defaults to ``True``. This
attribute does not carry over to groups. You must set it to every group if attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well. you require group commands to be case insensitive as well.
description: :class:`str` description: :class:`str`
@ -1102,11 +1409,36 @@ class Bot(BotBase, discord.Client):
the ``command_prefix`` is set to ``!``. Defaults to ``False``. the ``command_prefix`` is set to ``!``. Defaults to ``False``.
.. versionadded:: 1.7 .. versionadded:: 1.7
message_commands: Optional[:class:`bool`]
Whether to process commands based on messages.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``message_command`` parameter
.. versionadded:: 2.0
slash_commands: Optional[:class:`bool`]
Whether to upload and process slash commands.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command`` parameter
.. versionadded:: 2.0
slash_command_guilds: Optional[:class:`List[int]`]
If this is set, only upload slash commands to these guild IDs.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command_guilds`` parameter
.. versionadded:: 2.0
""" """
pass pass
class AutoShardedBot(BotBase, discord.AutoShardedClient): class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from """This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead. :class:`discord.AutoShardedClient` instead.
""" """
pass pass

View File

@ -36,15 +36,16 @@ if TYPE_CHECKING:
from .core import Command from .core import Command
__all__ = ( __all__ = (
'CogMeta', "CogMeta",
'Cog', "Cog",
) )
CogT = TypeVar('CogT', bound='Cog') CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar('FuncT', bound=Callable[..., Any]) FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
class CogMeta(type): class CogMeta(type):
"""A metaclass for defining a cog. """A metaclass for defining a cog.
@ -104,6 +105,7 @@ class CogMeta(type):
async def bar(self, ctx): async def bar(self, ctx):
pass # hidden -> False pass # hidden -> False
""" """
__cog_name__: str __cog_name__: str
__cog_settings__: Dict[str, Any] __cog_settings__: Dict[str, Any]
__cog_commands__: List[Command] __cog_commands__: List[Command]
@ -111,17 +113,17 @@ class CogMeta(type):
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
name, bases, attrs = args name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name) attrs["__cog_name__"] = kwargs.pop("name", name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
description = kwargs.pop('description', None) description = kwargs.pop("description", None)
if description is None: if description is None:
description = inspect.cleandoc(attrs.get('__doc__', '')) description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs['__cog_description__'] = description attrs["__cog_description__"] = description
commands = {} commands = {}
listeners = {} listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
new_cls = super().__new__(cls, name, bases, attrs, **kwargs) new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__): for base in reversed(new_cls.__mro__):
@ -136,21 +138,21 @@ class CogMeta(type):
value = value.__func__ value = value.__func__
if isinstance(value, _BaseCommand): if isinstance(value, _BaseCommand):
if is_static_method: if is_static_method:
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.")
if elem.startswith(('cog_', 'bot_')): if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem)) raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value commands[elem] = value
elif inspect.iscoroutinefunction(value): elif inspect.iscoroutinefunction(value):
try: try:
getattr(value, '__cog_listener__') getattr(value, "__cog_listener__")
except AttributeError: except AttributeError:
continue continue
else: else:
if elem.startswith(('cog_', 'bot_')): if elem.startswith(("cog_", "bot_")):
raise TypeError(no_bot_cog.format(base, elem)) raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
listeners_as_list = [] listeners_as_list = []
for listener in listeners.values(): for listener in listeners.values():
@ -169,10 +171,12 @@ class CogMeta(type):
def qualified_name(cls) -> str: def qualified_name(cls) -> str:
return cls.__cog_name__ return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT: def _cog_special_method(func: FuncT) -> FuncT:
func.__cog_special_method__ = None func.__cog_special_method__ = None
return func return func
class Cog(metaclass=CogMeta): class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from. """The base class that all cogs must inherit from.
@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta` When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here. are equally valid here.
""" """
__cog_name__: ClassVar[str] __cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]] __cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]] __cog_commands__: ClassVar[List[Command]]
@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta):
# r.e type ignore, type-checker complains about overriding a ClassVar # r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = { lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
# Update the Command instances dynamically as well # Update the Command instances dynamically as well
for command in self.__cog_commands__: for command in self.__cog_commands__:
@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta):
A command or group from the cog. A command or group from the cog.
""" """
from .core import GroupMixin from .core import GroupMixin
for command in self.__cog_commands__: for command in self.__cog_commands__:
if command.parent is None: if command.parent is None:
yield command yield command
@ -274,7 +277,7 @@ class Cog(metaclass=CogMeta):
@classmethod @classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
"""Return None if the method is not overridden. Otherwise returns the overridden method.""" """Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method) return getattr(method.__func__, "__cog_special_method__", method)
@classmethod @classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
@ -296,14 +299,14 @@ class Cog(metaclass=CogMeta):
""" """
if name is not MISSING and not isinstance(name, str): if name is not MISSING and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.")
def decorator(func: FuncT) -> FuncT: def decorator(func: FuncT) -> FuncT:
actual = func actual = func
if isinstance(actual, staticmethod): if isinstance(actual, staticmethod):
actual = actual.__func__ actual = actual.__func__
if not inspect.iscoroutinefunction(actual): if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.') raise TypeError("Listener function must be a coroutine function.")
actual.__cog_listener__ = True actual.__cog_listener__ = True
to_assign = name or actual.__name__ to_assign = name or actual.__name__
try: try:
@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and # to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function # thus the assignments need to be on the actual function
return func return func
return decorator return decorator
def has_error_handler(self) -> bool: def has_error_handler(self) -> bool:
@ -322,7 +326,7 @@ class Cog(metaclass=CogMeta):
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
@_cog_special_method @_cog_special_method
def cog_unload(self) -> None: def cog_unload(self) -> None:

View File

@ -22,16 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import asyncio
import inspect import inspect
import re import re
from datetime import timedelta
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union from typing import Any, Dict, Generic, List, Literal, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union, overload
import discord.abc import discord.abc
import discord.utils import discord.utils
from discord.message import Message from discord.message import Message
from discord import Permissions
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -41,6 +43,8 @@ if TYPE_CHECKING:
from discord.member import Member from discord.member import Member
from discord.state import ConnectionState from discord.state import ConnectionState
from discord.user import ClientUser, User from discord.user import ClientUser, User
from discord.webhook import WebhookMessage
from discord.interactions import Interaction
from discord.voice_client import VoiceProtocol from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot from .bot import Bot, AutoShardedBot
@ -49,21 +53,19 @@ if TYPE_CHECKING:
from .help import HelpCommand from .help import HelpCommand
from .view import StringView from .view import StringView
__all__ = ( __all__ = ("Context",)
'Context',
)
MISSING: Any = discord.utils.MISSING MISSING: Any = discord.utils.MISSING
T = TypeVar('T') T = TypeVar("T")
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING: if TYPE_CHECKING:
P = ParamSpec('P') P = ParamSpec("P")
else: else:
P = TypeVar('P') P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]): class Context(discord.abc.Messageable, Generic[BotT]):
@ -121,8 +123,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
A boolean that indicates if the command failed to be parsed, checked, A boolean that indicates if the command failed to be parsed, checked,
or invoked. or invoked.
""" """
interaction: Optional[Interaction] = None
def __init__(self, def __init__(
self,
*, *,
message: Message, message: Message,
bot: BotT, bot: BotT,
@ -151,6 +155,8 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.subcommand_passed: Optional[str] = subcommand_passed self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._ignored_params: List[inspect.Parameter] = []
self._typing_task: Optional[asyncio.Task[NoReturn]] = None
self._state: ConnectionState = self.message._state self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
@ -219,7 +225,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command cmd = self.command
view = self.view view = self.view
if cmd is None: if cmd is None:
raise ValueError('This context is not valid.') raise ValueError("This context is not valid.")
# some state to revert to when we're done # some state to revert to when we're done
index, previous = view.index, view.previous index, previous = view.index, view.previous
@ -230,10 +236,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart: if restart:
to_call = cmd.root_parent or cmd to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '') view.index = len(self.prefix or "")
view.previous = 0 view.previous = 0
self.invoked_parents = [] self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command self.invoked_with = view.get_word() # advance to get the root command
else: else:
to_call = cmd to_call = cmd
@ -263,7 +269,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
if self.prefix is None: if self.prefix is None:
return '' return ""
user = self.me user = self.me
# this breaks if the prefix mention is not the bot itself but I # this breaks if the prefix mention is not the bot itself but I
@ -271,7 +277,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
# for this common use case rather than waste performance for the # for this common use case rather than waste performance for the
# odd one. # odd one.
pattern = re.compile(r"<@!?%s>" % user.id) pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix) return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
@property @property
def cog(self) -> Optional[Cog]: def cog(self) -> Optional[Cog]:
@ -314,6 +320,13 @@ class Context(discord.abc.Messageable, Generic[BotT]):
g = self.guild g = self.guild
return g.voice_client if g else None return g.voice_client if g else None
def author_permissions(self) -> Permissions:
"""Returns the author permissions in the given channel.
.. versionadded:: 2.0
"""
return self.channel.permissions_for(self.author)
async def send_help(self, *args: Any) -> Any: async def send_help(self, *args: Any) -> Any:
"""send_help(entity=<bot>) """send_help(entity=<bot>)
@ -381,7 +394,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name) await cmd.prepare_help_command(self, entity.qualified_name)
try: try:
if hasattr(entity, '__cog_commands__'): if hasattr(entity, "__cog_commands__"):
injected = wrap_callback(cmd.send_cog_help) injected = wrap_callback(cmd.send_cog_help)
return await injected(entity) return await injected(entity)
elif isinstance(entity, Group): elif isinstance(entity, Group):
@ -395,6 +408,128 @@ class Context(discord.abc.Messageable, Generic[BotT]):
except CommandError as e: except CommandError as e:
await cmd.on_help_command_error(self, e) await cmd.on_help_command_error(self, e)
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[False] = False,
ephemeral: bool = False,
**kwargs: Any,
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[True] = True,
ephemeral: bool = False,
**kwargs: Any,
) -> Union[Message, WebhookMessage]:
...
async def send(
self, content: Optional[str] = None, return_message: bool = True, ephemeral: bool = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
"""
|coro|
A shortcut method to :meth:`.abc.Messageable.send` with interaction helpers.
This function takes all the parameters of :meth:`.abc.Messageable.send` plus the following:
Parameters
------------
return_message: :class:`bool`
Ignored if not in a slash command context.
If this is set to False more native interaction methods will be used.
ephemeral: :class:`bool`
Ignored if not in a slash command context.
Indicates if the message should only be visible to the user who started the interaction.
If a view is sent with an ephemeral message and it has no timeout set then the timeout
is set to 15 minutes.
Returns
--------
Optional[Union[:class:`.Message`, :class:`.WebhookMessage`]]
In a slash command context, the message that was sent if return_message is True.
In a normal context, it always returns a :class:`.Message`
"""
if self._typing_task is not None:
self._typing_task.cancel()
self._typing_task = None
if self.interaction is None or (
self.interaction.response.responded_at is not None
and discord.utils.utcnow() - self.interaction.response.responded_at >= timedelta(minutes=15)
):
return await super().send(content, **kwargs)
# Remove unsupported arguments from kwargs
kwargs.pop("nonce", None)
kwargs.pop("stickers", None)
kwargs.pop("reference", None)
kwargs.pop("delete_after", None)
kwargs.pop("mention_author", None)
if not (
return_message
or self.interaction.response.is_done()
or any(arg in kwargs for arg in ("file", "files", "allowed_mentions"))
):
send = self.interaction.response.send_message
else:
# We have to defer in order to use the followup webhook
if not self.interaction.response.is_done():
await self.interaction.response.defer(ephemeral=ephemeral)
send = self.interaction.followup.send
return await send(content, ephemeral=ephemeral, **kwargs) # type: ignore
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[False] = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[True] = True, **kwargs: Any
) -> Union[Message, WebhookMessage]:
...
@discord.utils.copy_doc(Message.reply) @discord.utils.copy_doc(Message.reply)
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: async def reply(
return await self.message.reply(content, **kwargs) self, content: Optional[str] = None, return_message: bool = True, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
return await self.send(content, return_message=return_message, reference=self.message, **kwargs) # type: ignore
async def defer(self, *, ephemeral: bool = False, trigger_typing: bool = True) -> None:
"""|coro|
Defers the Slash Command interaction if ran in a slash command **or**
Loops triggering ``Bot is typing`` in the channel if run in a message command.
Parameters
------------
trigger_typing: :class:`bool`
Indicates whether to trigger typing in a message command.
ephemeral: :class:`bool`
Indicates whether the deferred message will eventually be ephemeral in a slash command.
"""
if self.interaction is None:
if self._typing_task is None and trigger_typing:
async def typing_task():
while True:
await self.trigger_typing()
await asyncio.sleep(10)
self._typing_task = self.bot.loop.create_task(typing_task())
else:
await self.interaction.response.defer(ephemeral=ephemeral)

View File

@ -52,32 +52,33 @@ if TYPE_CHECKING:
__all__ = ( __all__ = (
'Converter', "Converter",
'ObjectConverter', "ObjectConverter",
'MemberConverter', "MemberConverter",
'UserConverter', "UserConverter",
'MessageConverter', "MessageConverter",
'PartialMessageConverter', "PartialMessageConverter",
'TextChannelConverter', "TextChannelConverter",
'InviteConverter', "InviteConverter",
'GuildConverter', "GuildConverter",
'RoleConverter', "RoleConverter",
'GameConverter', "GameConverter",
'ColourConverter', "ColourConverter",
'ColorConverter', "ColorConverter",
'VoiceChannelConverter', "VoiceChannelConverter",
'StageChannelConverter', "StageChannelConverter",
'EmojiConverter', "EmojiConverter",
'PartialEmojiConverter', "PartialEmojiConverter",
'CategoryChannelConverter', "CategoryChannelConverter",
'IDConverter', "IDConverter",
'StoreChannelConverter', "StoreChannelConverter",
'ThreadConverter', "ThreadConverter",
'GuildChannelConverter', "GuildChannelConverter",
'GuildStickerConverter', "GuildStickerConverter",
'clean_content', "clean_content",
'Greedy', "Greedy",
'run_converters', "Option",
"run_converters",
) )
@ -91,10 +92,12 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get _utils_get = discord.utils.get
T = TypeVar('T') T = TypeVar("T")
T_co = TypeVar('T_co', covariant=True) T_co = TypeVar("T_co", covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel) CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread) TT = TypeVar("TT", bound=discord.Thread)
DT = TypeVar("DT", bound=str)
@runtime_checkable @runtime_checkable
@ -132,10 +135,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument` :exc:`.BadArgument`
The converter failed to convert the argument. The converter failed to convert the argument.
""" """
raise NotImplementedError('Derived classes need to implement this.') raise NotImplementedError("Derived classes need to implement this.")
_ID_REGEX = re.compile(r'([0-9]{15,20})$') _ID_REGEX = re.compile(r"([0-9]{15,20})$")
class IDConverter(Converter[T_co]): class IDConverter(Converter[T_co]):
@ -158,7 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Object: async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
if match is None: if match is None:
raise ObjectNotFound(argument) raise ObjectNotFound(argument)
@ -192,8 +195,8 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument): async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#': if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition('#') username, _, discriminator = argument.rpartition("#")
members = await guild.query_members(username, limit=100, cache=cache) members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator) return discord.utils.get(members, name=username, discriminator=discriminator)
else: else:
@ -223,7 +226,7 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member: async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
guild = ctx.guild guild = ctx.guild
result = None result = None
user_id = None user_id = None
@ -232,13 +235,13 @@ class MemberConverter(IDConverter[discord.Member]):
if guild: if guild:
result = guild.get_member_named(argument) result = guild.get_member_named(argument)
else: else:
result = _get_from_guilds(bot, 'get_member_named', argument) result = _get_from_guilds(bot, "get_member_named", argument)
else: else:
user_id = int(match.group(1)) user_id = int(match.group(1))
if guild: if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
else: else:
result = _get_from_guilds(bot, 'get_member', user_id) result = _get_from_guilds(bot, "get_member", user_id)
if result is None: if result is None:
if guild is None: if guild is None:
@ -276,7 +279,7 @@ class UserConverter(IDConverter[discord.User]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.User: async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
result = None result = None
state = ctx._state state = ctx._state
@ -294,12 +297,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument arg = argument
# Remove the '@' character if this is the first character from the argument # Remove the '@' character if this is the first character from the argument
if arg[0] == '@': if arg[0] == "@":
# Remove first character # Remove first character
arg = arg[1:] arg = arg[1:]
# check for discriminator if it exists, # check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#': if len(arg) > 5 and arg[-5] == "#":
discrim = arg[-4:] discrim = arg[-4:]
name = arg[:-5] name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim predicate = lambda u: u.name == name and u.discriminator == discrim
@ -330,22 +333,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _get_id_matches(ctx, argument): def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$') id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
link_regex = re.compile( link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/' r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r'(?P<guild_id>[0-9]{15,20}|@me)' r"(?P<guild_id>[0-9]{15,20}|@me)"
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$' r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
) )
match = id_regex.match(argument) or link_regex.match(argument) match = id_regex.match(argument) or link_regex.match(argument)
if not match: if not match:
raise MessageNotFound(argument) raise MessageNotFound(argument)
data = match.groupdict() data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, 'channel_id') channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data['message_id']) message_id = int(data["message_id"])
guild_id = data.get('guild_id') guild_id = data.get("guild_id")
if guild_id is None: if guild_id is None:
guild_id = ctx.guild and ctx.guild.id guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me': elif guild_id == "@me":
guild_id = None guild_id = None
else: else:
guild_id = int(guild_id) guild_id = int(guild_id)
@ -417,13 +420,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
@staticmethod @staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None result = None
guild = ctx.guild guild = ctx.guild
@ -443,7 +446,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild: if guild:
result = guild.get_channel(channel_id) result = guild.get_channel(channel_id)
else: else:
result = _get_from_guilds(bot, 'get_channel', channel_id) result = _get_from_guilds(bot, "get_channel", channel_id)
if not isinstance(result, type): if not isinstance(result, type):
raise ChannelNotFound(argument) raise ChannelNotFound(argument)
@ -454,7 +457,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None result = None
guild = ctx.guild guild = ctx.guild
@ -491,7 +494,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@ -511,7 +514,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
class StageChannelConverter(IDConverter[discord.StageChannel]): class StageChannelConverter(IDConverter[discord.StageChannel]):
@ -530,7 +533,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@ -550,7 +553,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
class StoreChannelConverter(IDConverter[discord.StoreChannel]): class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@ -569,7 +572,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
class ThreadConverter(IDConverter[discord.Thread]): class ThreadConverter(IDConverter[discord.Thread]):
@ -583,11 +586,11 @@ class ThreadConverter(IDConverter[discord.Thread]):
2. Lookup by mention. 2. Lookup by mention.
3. Lookup by name. 3. Lookup by name.
.. versionadded: 2.0 .. versionadded:: 2.0
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Thread: async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
class ColourConverter(Converter[discord.Colour]): class ColourConverter(Converter[discord.Colour]):
@ -616,10 +619,10 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts Added support for ``rgb`` function and 3-digit hex shortcuts
""" """
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)') RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
def parse_hex_number(self, argument): def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
try: try:
value = int(arg, base=16) value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF): if not (0 <= value <= 0xFFFFFF):
@ -630,7 +633,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value) return discord.Color(value=value)
def parse_rgb_number(self, argument, number): def parse_rgb_number(self, argument, number):
if number[-1] == '%': if number[-1] == "%":
value = int(number[:-1]) value = int(number[:-1])
if not (0 <= value <= 100): if not (0 <= value <= 100):
raise BadColourArgument(argument) raise BadColourArgument(argument)
@ -646,29 +649,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None: if match is None:
raise BadColourArgument(argument) raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r')) red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group('g')) green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group('b')) blue = self.parse_rgb_number(argument, match.group("b"))
return discord.Color.from_rgb(red, green, blue) return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour: async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == '#': if argument[0] == "#":
return self.parse_hex_number(argument[1:]) return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x': if argument[0:2] == "0x":
rest = argument[2:] rest = argument[2:]
# Legacy backwards compatible syntax # Legacy backwards compatible syntax
if rest.startswith('#'): if rest.startswith("#"):
return self.parse_hex_number(rest[1:]) return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest) return self.parse_hex_number(rest)
arg = argument.lower() arg = argument.lower()
if arg[0:3] == 'rgb': if arg[0:3] == "rgb":
return self.parse_rgb(arg) return self.parse_rgb(arg)
arg = arg.replace(' ', '_') arg = arg.replace(" ", "_")
method = getattr(discord.Colour, arg, None) method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method): if arg.startswith("from_") or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg) raise BadColourArgument(arg)
return method() return method()
@ -697,7 +700,7 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild: if not guild:
raise NoPrivateMessage() raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
if match: if match:
result = guild.get_role(int(match.group(1))) result = guild.get_role(int(match.group(1)))
else: else:
@ -776,7 +779,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Emoji: async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument) match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
result = None result = None
bot = ctx.bot bot = ctx.bot
guild = ctx.guild guild = ctx.guild
@ -810,7 +813,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
if match: if match:
emoji_animated = bool(match.group(1)) emoji_animated = bool(match.group(1))
@ -903,37 +906,37 @@ class clean_content(Converter[str]):
def resolve_member(id: int) -> str: def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str: def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f'@{r.name}' if r else '@deleted-role' return f"@{r.name}" if r else "@deleted-role"
else: else:
def resolve_member(id: int) -> str: def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f'@{m.name}' if m else '@deleted-user' return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str: def resolve_role(id: int) -> str:
return '@deleted-role' return "@deleted-role"
if self.fix_channel_mentions and ctx.guild: if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str: def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id) c = ctx.guild.get_channel(id)
return f'#{c.name}' if c else '#deleted-channel' return f"#{c.name}" if c else "#deleted-channel"
else: else:
def resolve_channel(id: int) -> str: def resolve_channel(id: int) -> str:
return f'<#{id}>' return f"<#{id}>"
transforms = { transforms = {
'@': resolve_member, "@": resolve_member,
'@!': resolve_member, "@!": resolve_member,
'#': resolve_channel, "#": resolve_channel,
'@&': resolve_role, "@&": resolve_role,
} }
def repl(match: re.Match) -> str: def repl(match: re.Match) -> str:
@ -942,7 +945,7 @@ class clean_content(Converter[str]):
transformed = transforms[type](id) transformed = transforms[type](id)
return transformed return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown: if self.escape_markdown:
result = discord.utils.escape_markdown(result) result = discord.utils.escape_markdown(result)
elif self.remove_markdown: elif self.remove_markdown:
@ -974,42 +977,89 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`. For more information, check :ref:`ext_commands_special_converters`.
""" """
__slots__ = ('converter',) __slots__ = ("converter",)
def __init__(self, *, converter: T): def __init__(self, *, converter: T):
self.converter = converter self.converter = converter
def __repr__(self): def __repr__(self):
converter = getattr(self.converter, '__name__', repr(self.converter)) converter = getattr(self.converter, "__name__", repr(self.converter))
return f'Greedy[{converter}]' return f"Greedy[{converter}]"
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple): if not isinstance(params, tuple):
params = (params,) params = (params,)
if len(params) != 1: if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument') raise TypeError("Greedy[...] only takes a single argument")
converter = params[0] converter = params[0]
origin = getattr(converter, '__origin__', None) origin = getattr(converter, "__origin__", None)
args = getattr(converter, '__args__', ()) args = getattr(converter, "__args__", ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None): if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError('Greedy[...] expects a type or a Converter instance.') raise TypeError("Greedy[...] expects a type or a Converter instance.")
if converter in (str, type(None)) or origin is Greedy: if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f'Greedy[{converter.__name__}] is invalid.') raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
if origin is Union and type(None) in args: if origin is Union and type(None) in args:
raise TypeError(f'Greedy[{converter!r}] is invalid.') raise TypeError(f"Greedy[{converter!r}] is invalid.")
return cls(converter=converter) return cls(converter=converter)
class Option(Generic[T, DT]): # type: ignore
"""A special 'converter' to apply a description to slash command options.
For example in the following code:
.. code-block:: python3
@bot.command()
async def ban(ctx,
member: discord.Member, *,
reason: str = commands.Option('no reason', description='the reason to ban this member')
):
await member.ban(reason=reason)
The description would be ``the reason to ban this member`` and the default would be ``no reason``
.. versionadded:: 2.0
Attributes
------------
default: Optional[Any]
The default for this option, overwrites Option during parsing.
description: :class:`str`
The description for this option, is unpacked to :attr:`.Command.option_descriptions`
name: :class:`str`
The name of the option. This defaults to the parameter name.
"""
description: DT
default: Union[T, inspect._empty]
__slots__ = (
"default",
"description",
"name",
)
def __init__(
self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING
) -> None:
self.description = description
self.default = default
self.name: str = name
Option: Any
def _convert_to_bool(argument: str) -> bool: def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower() lowered = argument.lower()
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
return True return True
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
return False return False
else: else:
raise BadBoolArgument(lowered) raise BadBoolArgument(lowered)
@ -1065,7 +1115,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError: except AttributeError:
pass pass
else: else:
if module is not None and (module.startswith('discord.') and not module.endswith('converter')): if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
converter = CONVERTER_MAPPING.get(converter, converter) converter = CONVERTER_MAPPING.get(converter, converter)
try: try:
@ -1124,7 +1174,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any Any
The resulting conversion. The resulting conversion.
""" """
origin = getattr(converter, '__origin__', None) origin = getattr(converter, "__origin__", None)
if origin is Union: if origin is Union:
errors = [] errors = []

View File

@ -38,24 +38,25 @@ if TYPE_CHECKING:
from ...message import Message from ...message import Message
__all__ = ( __all__ = (
'BucketType', "BucketType",
'Cooldown', "Cooldown",
'CooldownMapping', "CooldownMapping",
'DynamicCooldownMapping', "DynamicCooldownMapping",
'MaxConcurrency', "MaxConcurrency",
) )
C = TypeVar('C', bound='CooldownMapping') C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar('MC', bound='MaxConcurrency') MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum): class BucketType(Enum):
default = 0 default = 0
user = 1 user = 1
guild = 2 guild = 2
channel = 3 channel = 3
member = 4 member = 4
category = 5 category = 5
role = 6 role = 6
def get_key(self, msg: Message) -> Any: def get_key(self, msg: Message) -> Any:
if self is BucketType.user: if self is BucketType.user:
@ -90,7 +91,7 @@ class Cooldown:
The length of the cooldown period in seconds. The length of the cooldown period in seconds.
""" """
__slots__ = ('rate', 'per', '_window', '_tokens', '_last') __slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None: def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate) self.rate: int = int(rate)
@ -190,7 +191,8 @@ class Cooldown:
return Cooldown(self.rate, self.per) return Cooldown(self.rate, self.per)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>' return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
class CooldownMapping: class CooldownMapping:
def __init__( def __init__(
@ -199,7 +201,7 @@ class CooldownMapping:
type: Callable[[Message], Any], type: Callable[[Message], Any],
) -> None: ) -> None:
if not callable(type): if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable') raise TypeError("Cooldown type must be a BucketType or callable")
self._cache: Dict[Any, Cooldown] = {} self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original self._cooldown: Optional[Cooldown] = original
@ -256,13 +258,9 @@ class CooldownMapping:
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__( class DynamicCooldownMapping(CooldownMapping):
self, def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None:
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
super().__init__(None, type) super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory self._factory: Callable[[Message], Cooldown] = factory
@ -278,6 +276,7 @@ class DynamicCooldownMapping(CooldownMapping):
def create_bucket(self, message: Message) -> Cooldown: def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message) return self._factory(message)
class _Semaphore: class _Semaphore:
"""This class is a version of a semaphore. """This class is a version of a semaphore.
@ -291,7 +290,7 @@ class _Semaphore:
overkill for what is basically a counter. overkill for what is basically a counter.
""" """
__slots__ = ('value', 'loop', '_waiters') __slots__ = ("value", "loop", "_waiters")
def __init__(self, number: int) -> None: def __init__(self, number: int) -> None:
self.value: int = number self.value: int = number
@ -299,7 +298,7 @@ class _Semaphore:
self._waiters: Deque[asyncio.Future] = deque() self._waiters: Deque[asyncio.Future] = deque()
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def locked(self) -> bool: def locked(self) -> bool:
return self.value == 0 return self.value == 0
@ -337,8 +336,9 @@ class _Semaphore:
self.value += 1 self.value += 1
self.wake_up() self.wake_up()
class MaxConcurrency: class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping') __slots__ = ("number", "per", "wait", "_mapping")
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {} self._mapping: Dict[Any, _Semaphore] = {}
@ -347,16 +347,16 @@ class MaxConcurrency:
self.wait: bool = wait self.wait: bool = wait
if number <= 0: if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1') raise ValueError("max_concurrency 'number' cannot be less than 1")
if not isinstance(per, BucketType): if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
def copy(self: MC) -> MC: def copy(self: MC) -> MC:
return self.__class__(self.number, per=self.per, wait=self.wait) return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>' return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
def get_key(self, message: Message) -> Any: def get_key(self, message: Message) -> Any:
return self.per.get_key(message) return self.per.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@ -33,6 +33,7 @@ if TYPE_CHECKING:
from .converter import Converter from .converter import Converter
from .context import Context from .context import Context
from .core import Command
from .cooldowns import Cooldown, BucketType from .cooldowns import Cooldown, BucketType
from .flags import Flag from .flags import Flag
from discord.abc import GuildChannel from discord.abc import GuildChannel
@ -41,65 +42,67 @@ if TYPE_CHECKING:
__all__ = ( __all__ = (
'CommandError', "CommandError",
'MissingRequiredArgument', "MissingRequiredArgument",
'BadArgument', "BadArgument",
'PrivateMessageOnly', "PrivateMessageOnly",
'NoPrivateMessage', "NoPrivateMessage",
'CheckFailure', "CheckFailure",
'CheckAnyFailure', "CheckAnyFailure",
'CommandNotFound', "CommandNotFound",
'DisabledCommand', "DisabledCommand",
'CommandInvokeError', "CommandInvokeError",
'TooManyArguments', "TooManyArguments",
'UserInputError', "UserInputError",
'CommandOnCooldown', "CommandOnCooldown",
'MaxConcurrencyReached', "MaxConcurrencyReached",
'NotOwner', "NotOwner",
'MessageNotFound', "MessageNotFound",
'ObjectNotFound', "ObjectNotFound",
'MemberNotFound', "MemberNotFound",
'GuildNotFound', "GuildNotFound",
'UserNotFound', "UserNotFound",
'ChannelNotFound', "ChannelNotFound",
'ThreadNotFound', "ThreadNotFound",
'ChannelNotReadable', "ChannelNotReadable",
'BadColourArgument', "BadColourArgument",
'BadColorArgument', "BadColorArgument",
'RoleNotFound', "RoleNotFound",
'BadInviteArgument', "BadInviteArgument",
'EmojiNotFound', "EmojiNotFound",
'GuildStickerNotFound', "GuildStickerNotFound",
'PartialEmojiConversionFailure', "PartialEmojiConversionFailure",
'BadBoolArgument', "BadBoolArgument",
'MissingRole', "MissingRole",
'BotMissingRole', "BotMissingRole",
'MissingAnyRole', "MissingAnyRole",
'BotMissingAnyRole', "BotMissingAnyRole",
'MissingPermissions', "MissingPermissions",
'BotMissingPermissions', "BotMissingPermissions",
'NSFWChannelRequired', "NSFWChannelRequired",
'ConversionError', "ConversionError",
'BadUnionArgument', "BadUnionArgument",
'BadLiteralArgument', "BadLiteralArgument",
'ArgumentParsingError', "ArgumentParsingError",
'UnexpectedQuoteError', "UnexpectedQuoteError",
'InvalidEndOfQuotedStringError', "InvalidEndOfQuotedStringError",
'ExpectedClosingQuoteError', "ExpectedClosingQuoteError",
'ExtensionError', "ExtensionError",
'ExtensionAlreadyLoaded', "ExtensionAlreadyLoaded",
'ExtensionNotLoaded', "ExtensionNotLoaded",
'NoEntryPointError', "NoEntryPointError",
'ExtensionFailed', "ExtensionFailed",
'ExtensionNotFound', "ExtensionNotFound",
'CommandRegistrationError', "CommandRegistrationError",
'FlagError', "ApplicationCommandRegistrationError",
'BadFlagArgument', "FlagError",
'MissingFlagArgument', "BadFlagArgument",
'TooManyFlags', "MissingFlagArgument",
'MissingRequiredFlag', "TooManyFlags",
"MissingRequiredFlag",
) )
class CommandError(DiscordException): class CommandError(DiscordException):
r"""The base exception type for all command related errors. r"""The base exception type for all command related errors.
@ -109,14 +112,16 @@ class CommandError(DiscordException):
in a special way as they are caught and passed into a special event in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`. from :class:`.Bot`\, :func:`.on_command_error`.
""" """
def __init__(self, message: Optional[str] = None, *args: Any) -> None: def __init__(self, message: Optional[str] = None, *args: Any) -> None:
if message is not None: if message is not None:
# clean-up @everyone and @here mentions # clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args) super().__init__(m, *args)
else: else:
super().__init__(*args) super().__init__(*args)
class ConversionError(CommandError): class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError. """Exception raised when a Converter class raises non-CommandError.
@ -130,18 +135,22 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, converter: Converter, original: Exception) -> None: def __init__(self, converter: Converter, original: Exception) -> None:
self.converter: Converter = converter self.converter: Converter = converter
self.original: Exception = original self.original: Exception = original
class UserInputError(CommandError): class UserInputError(CommandError):
"""The base exception type for errors that involve errors """The base exception type for errors that involve errors
regarding user input. regarding user input.
This inherits from :exc:`CommandError`. This inherits from :exc:`CommandError`.
""" """
pass pass
class CommandNotFound(CommandError): class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked """Exception raised when a command is attempted to be invoked
but no command under that name is found. but no command under that name is found.
@ -151,8 +160,10 @@ class CommandNotFound(CommandError):
This inherits from :exc:`CommandError`. This inherits from :exc:`CommandError`.
""" """
pass pass
class MissingRequiredArgument(UserInputError): class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter """Exception raised when parsing a command and a parameter
that is required is not encountered. that is required is not encountered.
@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter` param: :class:`inspect.Parameter`
The argument that is missing. The argument that is missing.
""" """
def __init__(self, param: Parameter) -> None: def __init__(self, param: Parameter) -> None:
self.param: Parameter = param self.param: Parameter = param
super().__init__(f'{param.name} is a required argument that is missing.') super().__init__(f"{param.name} is a required argument that is missing.")
class TooManyArguments(UserInputError): class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its """Exception raised when the command was passed too many arguments and its
@ -174,23 +187,29 @@ class TooManyArguments(UserInputError):
This inherits from :exc:`UserInputError` This inherits from :exc:`UserInputError`
""" """
pass pass
class BadArgument(UserInputError): class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered """Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command. on an argument to pass into a command.
This inherits from :exc:`UserInputError` This inherits from :exc:`UserInputError`
""" """
pass pass
class CheckFailure(CommandError): class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed. """Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError` This inherits from :exc:`CommandError`
""" """
pass pass
class CheckAnyFailure(CheckFailure): class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail. """Exception raised when all predicates in :func:`check_any` fail.
@ -209,7 +228,8 @@ class CheckAnyFailure(CheckFailure):
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
self.checks: List[CheckFailure] = checks self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors self.errors: List[Callable[[Context], bool]] = errors
super().__init__('You do not have permission to run this command.') super().__init__("You do not have permission to run this command.")
class PrivateMessageOnly(CheckFailure): class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private """Exception raised when an operation does not work outside of private
@ -217,8 +237,10 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command can only be used in private messages.') super().__init__(message or "This command can only be used in private messages.")
class NoPrivateMessage(CheckFailure): class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message """Exception raised when an operation does not work in private message
@ -228,15 +250,18 @@ class NoPrivateMessage(CheckFailure):
""" """
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
super().__init__(message or 'This command cannot be used in private messages.') super().__init__(message or "This command cannot be used in private messages.")
class NotOwner(CheckFailure): class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot. """Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
pass pass
class ObjectNotFound(BadArgument): class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format """Exception raised when the argument provided did not match the format
of an ID or a mention. of an ID or a mention.
@ -250,9 +275,11 @@ class ObjectNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The argument supplied by the caller that was not matched The argument supplied by the caller that was not matched
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.') super().__init__(f"{argument!r} does not follow a valid ID or mention format.")
class MemberNotFound(BadArgument): class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's """Exception raised when the member provided was not found in the bot's
@ -267,10 +294,12 @@ class MemberNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The member supplied by the caller that was not found The member supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Member "{argument}" not found.') super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument): class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache. """Exception raised when the guild provided was not found in the bot's cache.
@ -283,10 +312,12 @@ class GuildNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The guild supplied by the called that was not found The guild supplied by the called that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Guild "{argument}" not found.') super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument): class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's """Exception raised when the user provided was not found in the bot's
cache. cache.
@ -300,10 +331,12 @@ class UserNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The user supplied by the caller that was not found The user supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'User "{argument}" not found.') super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument): class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel. """Exception raised when the message provided was not found in the channel.
@ -316,10 +349,12 @@ class MessageNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The message supplied by the caller that was not found The message supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Message "{argument}" not found.') super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument): class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages """Exception raised when the bot does not have permission to read messages
in the channel. in the channel.
@ -333,10 +368,12 @@ class ChannelNotReadable(BadArgument):
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel supplied by the caller that was not readable The channel supplied by the caller that was not readable
""" """
def __init__(self, argument: Union[GuildChannel, Thread]) -> None: def __init__(self, argument: Union[GuildChannel, Thread]) -> None:
self.argument: Union[GuildChannel, Thread] = argument self.argument: Union[GuildChannel, Thread] = argument
super().__init__(f"Can't read messages in {argument.mention}.") super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument): class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel. """Exception raised when the bot can not find the channel.
@ -349,10 +386,12 @@ class ChannelNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The channel supplied by the caller that was not found The channel supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Channel "{argument}" not found.') super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument): class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread. """Exception raised when the bot can not find the thread.
@ -365,10 +404,12 @@ class ThreadNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The thread supplied by the caller that was not found The thread supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.') super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument): class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid. """Exception raised when the colour is not valid.
@ -381,12 +422,15 @@ class BadColourArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The colour supplied by the caller that was not valid The colour supplied by the caller that was not valid
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Colour "{argument}" is invalid.') super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument): class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role. """Exception raised when the bot can not find the role.
@ -399,21 +443,30 @@ class RoleNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The role supplied by the caller that was not found The role supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Role "{argument}" not found.') super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument): class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired. """Exception raised when the invite is invalid or expired.
This inherits from :exc:`BadArgument` This inherits from :exc:`BadArgument`
.. versionadded:: 1.5 .. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The invite supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Invite "{argument}" is invalid or expired.') super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument): class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji. """Exception raised when the bot can not find the emoji.
@ -426,10 +479,12 @@ class EmojiNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that was not found The emoji supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Emoji "{argument}" not found.') super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument): class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct """Exception raised when the emoji provided does not match the correct
format. format.
@ -443,10 +498,12 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that did not match the regex The emoji supplied by the caller that did not match the regex
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument): class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker. """Exception raised when the bot can not find the sticker.
@ -459,10 +516,12 @@ class GuildStickerNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The sticker supplied by the caller that was not found The sticker supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.') super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument): class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable. """Exception raised when a boolean argument was not convertable.
@ -475,17 +534,21 @@ class BadBoolArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list The boolean argument supplied by the caller that is not in the predefined list
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument: str) -> None:
self.argument: str = argument self.argument: str = argument
super().__init__(f'{argument} is not a recognised boolean option') super().__init__(f"{argument} is not a recognised boolean option")
class DisabledCommand(CommandError): class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled. """Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError` This inherits from :exc:`CommandError`
""" """
pass pass
class CommandInvokeError(CommandError): class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception. """Exception raised when the command being invoked raised an exception.
@ -497,9 +560,11 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, e: Exception) -> None: def __init__(self, e: Exception) -> None:
self.original: Exception = e self.original: Exception = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}")
class CommandOnCooldown(CommandError): class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown. """Exception raised when the command being invoked is on cooldown.
@ -516,11 +581,13 @@ class CommandOnCooldown(CommandError):
retry_after: :class:`float` retry_after: :class:`float`
The amount of seconds to wait before you can retry again. The amount of seconds to wait before you can retry again.
""" """
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None:
self.cooldown: Cooldown = cooldown self.cooldown: Cooldown = cooldown
self.retry_after: float = retry_after self.retry_after: float = retry_after
self.type: BucketType = type self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s")
class MaxConcurrencyReached(CommandError): class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency. """Exception raised when the command being invoked has reached its maximum concurrency.
@ -539,10 +606,11 @@ class MaxConcurrencyReached(CommandError):
self.number: int = number self.number: int = number
self.per: BucketType = per self.per: BucketType = per
name = per.name name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally' suffix = "per %s" % name if per.name != "default" else "globally"
plural = '%s times %s' if number > 1 else '%s time %s' plural = "%s times %s" if number > 1 else "%s time %s"
fmt = plural % (number, suffix) fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.")
class MissingRole(CheckFailure): class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command. """Exception raised when the command invoker lacks a role to run a command.
@ -557,11 +625,13 @@ class MissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role self.missing_role: Snowflake = missing_role
message = f'Role {missing_role!r} is required to run this command.' message = f"Role {missing_role!r} is required to run this command."
super().__init__(message) super().__init__(message)
class BotMissingRole(CheckFailure): class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command. """Exception raised when the bot's member lacks a role to run a command.
@ -575,11 +645,13 @@ class BotMissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role: Snowflake) -> None:
self.missing_role: Snowflake = missing_role self.missing_role: Snowflake = missing_role
message = f'Bot requires the role {missing_role!r} to run this command' message = f"Bot requires the role {missing_role!r} to run this command"
super().__init__(message) super().__init__(message)
class MissingAnyRole(CheckFailure): class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of """Exception raised when the command invoker lacks any of
the roles specified to run a command. the roles specified to run a command.
@ -594,15 +666,16 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing. The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles] missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else: else:
fmt = ' or '.join(missing) fmt = " or ".join(missing)
message = f"You are missing at least one of the required roles: {fmt}" message = f"You are missing at least one of the required roles: {fmt}"
super().__init__(message) super().__init__(message)
@ -623,19 +696,21 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles: SnowflakeList) -> None:
self.missing_roles: SnowflakeList = missing_roles self.missing_roles: SnowflakeList = missing_roles
missing = [f"'{role}'" for role in missing_roles] missing = [f"'{role}'" for role in missing_roles]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1])
else: else:
fmt = ' or '.join(missing) fmt = " or ".join(missing)
message = f"Bot is missing at least one of the required roles: {fmt}" message = f"Bot is missing at least one of the required roles: {fmt}"
super().__init__(message) super().__init__(message)
class NSFWChannelRequired(CheckFailure): class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting. """Exception raised when a channel does not have the required NSFW setting.
@ -648,10 +723,12 @@ class NSFWChannelRequired(CheckFailure):
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`]
The channel that does not have NSFW enabled. The channel that does not have NSFW enabled.
""" """
def __init__(self, channel: Union[GuildChannel, Thread]) -> None: def __init__(self, channel: Union[GuildChannel, Thread]) -> None:
self.channel: Union[GuildChannel, Thread] = channel self.channel: Union[GuildChannel, Thread] = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure): class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a """Exception raised when the command invoker lacks permissions to run a
command. command.
@ -663,18 +740,20 @@ class MissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`] missing_permissions: List[:class:`str`]
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else: else:
fmt = ' and '.join(missing) fmt = " and ".join(missing)
message = f'You are missing {fmt} permission(s) to run this command.' message = f"You are missing {fmt} permission(s) to run this command."
super().__init__(message, *args) super().__init__(message, *args)
class BotMissingPermissions(CheckFailure): class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a """Exception raised when the bot's member lacks permissions to run a
command. command.
@ -686,18 +765,20 @@ class BotMissingPermissions(CheckFailure):
missing_permissions: List[:class:`str`] missing_permissions: List[:class:`str`]
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_permissions: List[str], *args: Any) -> None:
self.missing_permissions: List[str] = missing_permissions self.missing_permissions: List[str] = missing_permissions
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1])
else: else:
fmt = ' and '.join(missing) fmt = " and ".join(missing)
message = f'Bot requires {fmt} permission(s) to run this command.' message = f"Bot requires {fmt} permission(s) to run this command."
super().__init__(message, *args) super().__init__(message, *args)
class BadUnionArgument(UserInputError): class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all """Exception raised when a :data:`typing.Union` converter fails for all
its associated types. its associated types.
@ -713,6 +794,7 @@ class BadUnionArgument(UserInputError):
errors: List[:class:`CommandError`] errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.converters: Tuple[Type, ...] = converters self.converters: Tuple[Type, ...] = converters
@ -722,18 +804,19 @@ class BadUnionArgument(UserInputError):
try: try:
return x.__name__ return x.__name__
except AttributeError: except AttributeError:
if hasattr(x, '__origin__'): if hasattr(x, "__origin__"):
return repr(x) return repr(x)
return x.__class__.__name__ return x.__class__.__name__
to_string = [_get_name(x) for x in converters] to_string = [_get_name(x) for x in converters]
if len(to_string) > 2: if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else: else:
fmt = ' or '.join(to_string) fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into {fmt}.') super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError): class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all """Exception raised when a :data:`typing.Literal` converter fails for all
its associated values. its associated values.
@ -751,6 +834,7 @@ class BadLiteralArgument(UserInputError):
errors: List[:class:`CommandError`] errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals self.literals: Tuple[Any, ...] = literals
@ -758,12 +842,13 @@ class BadLiteralArgument(UserInputError):
to_string = [repr(l) for l in literals] to_string = [repr(l) for l in literals]
if len(to_string) > 2: if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1])
else: else:
fmt = ' or '.join(to_string) fmt = " or ".join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError): class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input. """An exception raised when the parser fails to parse a user's input.
@ -772,8 +857,10 @@ class ArgumentParsingError(UserInputError):
There are child classes that implement more granular parsing errors for There are child classes that implement more granular parsing errors for
i18n purposes. i18n purposes.
""" """
pass pass
class UnexpectedQuoteError(ArgumentParsingError): class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string. """An exception raised when the parser encounters a quote mark inside a non-quoted string.
@ -784,9 +871,11 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str` quote: :class:`str`
The quote mark that was found inside the non-quoted string. The quote mark that was found inside the non-quoted string.
""" """
def __init__(self, quote: str) -> None: def __init__(self, quote: str) -> None:
self.quote: str = quote self.quote: str = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string")
class InvalidEndOfQuotedStringError(ArgumentParsingError): class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string """An exception raised when a space is expected after the closing quote in a string
@ -799,9 +888,11 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str` char: :class:`str`
The character found instead of the expected string. The character found instead of the expected string.
""" """
def __init__(self, char: str) -> None: def __init__(self, char: str) -> None:
self.char: str = char self.char: str = char
super().__init__(f'Expected space after closing quotation but received {char!r}') super().__init__(f"Expected space after closing quotation but received {char!r}")
class ExpectedClosingQuoteError(ArgumentParsingError): class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found. """An exception raised when a quote character is expected but not found.
@ -816,7 +907,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
def __init__(self, close_quote: str) -> None: def __init__(self, close_quote: str) -> None:
self.close_quote: str = close_quote self.close_quote: str = close_quote
super().__init__(f'Expected closing {close_quote}.') super().__init__(f"Expected closing {close_quote}.")
class ExtensionError(DiscordException): class ExtensionError(DiscordException):
"""Base exception for extension related errors. """Base exception for extension related errors.
@ -828,37 +920,45 @@ class ExtensionError(DiscordException):
name: :class:`str` name: :class:`str`
The extension that had an error. The extension that had an error.
""" """
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None:
self.name: str = name self.name: str = name
message = message or f'Extension {name!r} had an error.' message = message or f"Extension {name!r} had an error."
# clean-up @everyone and @here mentions # clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere")
super().__init__(m, *args) super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError): class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded. """An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} is already loaded.', name=name) super().__init__(f"Extension {name!r} is already loaded.", name=name)
class ExtensionNotLoaded(ExtensionError): class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded. """An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f'Extension {name!r} has not been loaded.', name=name) super().__init__(f"Extension {name!r} has not been loaded.", name=name)
class NoEntryPointError(ExtensionError): class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function. """An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError): class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
@ -872,11 +972,13 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, name: str, original: Exception) -> None: def __init__(self, name: str, original: Exception) -> None:
self.original: Exception = original self.original: Exception = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}"
super().__init__(msg, name=name) super().__init__(msg, name=name)
class ExtensionNotFound(ExtensionError): class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found. """An exception raised when an extension is not found.
@ -890,10 +992,12 @@ class ExtensionNotFound(ExtensionError):
name: :class:`str` name: :class:`str`
The extension that had the error. The extension that had the error.
""" """
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
msg = f'Extension {name!r} could not be loaded.' msg = f"Extension {name!r} could not be loaded."
super().__init__(msg, name=name) super().__init__(msg, name=name)
class CommandRegistrationError(ClientException): class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added """An exception raised when the command can't be added
because the name is already taken by a different command. because the name is already taken by a different command.
@ -909,11 +1013,32 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool` alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add. Whether the name that conflicts is an alias of the command we try to add.
""" """
def __init__(self, name: str, *, alias_conflict: bool = False) -> None: def __init__(self, name: str, *, alias_conflict: bool = False) -> None:
self.name: str = name self.name: str = name
self.alias_conflict: bool = alias_conflict self.alias_conflict: bool = alias_conflict
type_ = 'alias' if alias_conflict else 'command' type_ = "alias" if alias_conflict else "command"
super().__init__(f'The {type_} {name} is already an existing command or alias.') super().__init__(f"The {type_} {name} is already an existing command or alias.")
class ApplicationCommandRegistrationError(ClientException):
"""An exception raised when a command cannot be converted to an
application command.
This inherits from :exc:`discord.ClientException`
.. versionadded:: 2.0
Attributes
----------
command: :class:`Command`
The command that failed to be converted.
"""
def __init__(self, command: Command, msg: str = None) -> None:
self.command = command
super().__init__(msg or f"{command.qualified_name} failed to converted to an application command.")
class FlagError(BadArgument): class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors. """The base exception type for all flag parsing related errors.
@ -922,8 +1047,10 @@ class FlagError(BadArgument):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
pass pass
class TooManyFlags(FlagError): class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values. """An exception raised when a flag has received too many values.
@ -938,10 +1065,12 @@ class TooManyFlags(FlagError):
values: List[:class:`str`] values: List[:class:`str`]
The values that were passed. The values that were passed.
""" """
def __init__(self, flag: Flag, values: List[str]) -> None: def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag self.flag: Flag = flag
self.values: List[str] = values self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.') super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.")
class BadFlagArgument(FlagError): class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value. """An exception raised when a flag failed to convert a value.
@ -955,6 +1084,7 @@ class BadFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert. The flag that failed to convert.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
try: try:
@ -962,7 +1092,8 @@ class BadFlagArgument(FlagError):
except AttributeError: except AttributeError:
name = flag.annotation.__class__.__name__ name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}') super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}")
class MissingRequiredFlag(FlagError): class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given. """An exception raised when a required flag was not given.
@ -976,9 +1107,11 @@ class MissingRequiredFlag(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found. The required flag that was not found.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing') super().__init__(f"Flag {flag.name!r} is required and missing")
class MissingFlagArgument(FlagError): class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value. """An exception raised when a flag did not get a value.
@ -992,6 +1125,7 @@ class MissingFlagArgument(FlagError):
flag: :class:`~discord.ext.commands.Flag` flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value. The flag that did not get a value.
""" """
def __init__(self, flag: Flag) -> None: def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument') super().__init__(f"Flag {flag.name!r} does not have an argument")

View File

@ -59,9 +59,9 @@ import sys
import re import re
__all__ = ( __all__ = (
'Flag', "Flag",
'flag', "flag",
'FlagConverter', "FlagConverter",
) )
@ -81,6 +81,8 @@ class Flag:
------------ ------------
name: :class:`str` name: :class:`str`
The name of the flag. The name of the flag.
description: :class:`str`
The description of the flag.
aliases: List[:class:`str`] aliases: List[:class:`str`]
The aliases of the flag name. The aliases of the flag name.
attribute: :class:`str` attribute: :class:`str`
@ -97,6 +99,7 @@ class Flag:
""" """
name: str = MISSING name: str = MISSING
description: str = MISSING
aliases: List[str] = field(default_factory=list) aliases: List[str] = field(default_factory=list)
attribute: str = MISSING attribute: str = MISSING
annotation: Any = MISSING annotation: Any = MISSING
@ -117,6 +120,7 @@ class Flag:
def flag( def flag(
*, *,
name: str = MISSING, name: str = MISSING,
description: str = MISSING,
aliases: List[str] = MISSING, aliases: List[str] = MISSING,
default: Any = MISSING, default: Any = MISSING,
max_args: int = MISSING, max_args: int = MISSING,
@ -129,6 +133,8 @@ def flag(
------------ ------------
name: :class:`str` name: :class:`str`
The flag name. If not given, defaults to the attribute name. The flag name. If not given, defaults to the attribute name.
description: :class:`str`
Description of the flag for the slash commands options. The default value is `'no description'`.
aliases: List[:class:`str`] aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set. Aliases to the flag name. If not given no aliases are set.
default: Any default: Any
@ -143,25 +149,27 @@ def flag(
Whether multiple given values overrides the previous value. The default Whether multiple given values overrides the previous value. The default
value depends on the annotation given. value depends on the annotation given.
""" """
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) return Flag(
name=name, description=description, aliases=aliases, default=default, max_args=max_args, override=override
)
def validate_flag_name(name: str, forbidden: Set[str]): def validate_flag_name(name: str, forbidden: Set[str]):
if not name: if not name:
raise ValueError('flag names should not be empty') raise ValueError("flag names should not be empty")
for ch in name: for ch in name:
if ch.isspace(): if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces') raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == '\\': if ch == "\\":
raise ValueError(f'flag name {name!r} cannot have backslashes') raise ValueError(f"flag name {name!r} cannot have backslashes")
if ch in forbidden: if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {}) annotations = namespace.get("__annotations__", {})
case_insensitive = namespace['__commands_flag_case_insensitive__'] case_insensitive = namespace["__commands_flag_case_insensitive__"]
flags: Dict[str, Flag] = {} flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {} cache: Dict[str, Any] = {}
names: Set[str] = set() names: Set[str] = set()
@ -178,7 +186,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible(): if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
flag.default = annotation._construct_default flag.default = annotation._construct_default
if flag.aliases is MISSING: if flag.aliases is MISSING:
@ -229,7 +241,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING: if flag.max_args is MISSING:
flag.max_args = 1 flag.max_args = 1
else: else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag")
if flag.override is MISSING: if flag.override is MISSING:
flag.override = False flag.override = False
@ -237,7 +249,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique # Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name name = flag.name.casefold() if case_insensitive else flag.name
if name in names: if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.') raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
else: else:
names.add(name) names.add(name)
@ -245,7 +257,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique # Validate alias is unique
alias = alias.casefold() if case_insensitive else alias alias = alias.casefold() if case_insensitive else alias
if alias in names: if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.') raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.")
else: else:
names.add(alias) names.add(alias)
@ -274,10 +286,10 @@ class FlagsMeta(type):
delimiter: str = MISSING, delimiter: str = MISSING,
prefix: str = MISSING, prefix: str = MISSING,
): ):
attrs['__commands_is_flag__'] = True attrs["__commands_is_flag__"] = True
try: try:
global_ns = sys.modules[attrs['__module__']].__dict__ global_ns = sys.modules[attrs["__module__"]].__dict__
except KeyError: except KeyError:
global_ns = {} global_ns = {}
@ -296,26 +308,26 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {} flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {} aliases: Dict[str, str] = {}
for base in reversed(bases): for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False): if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__['__commands_flags__']) flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__['__commands_flag_aliases__']) aliases.update(base.__dict__["__commands_flag_aliases__"])
if case_insensitive is MISSING: if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__'] attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"]
if delimiter is MISSING: if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__'] attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"]
if prefix is MISSING: if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__'] attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
if case_insensitive is not MISSING: if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive attrs["__commands_flag_case_insensitive__"] = case_insensitive
if delimiter is not MISSING: if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter attrs["__commands_flag_delimiter__"] = delimiter
if prefix is not MISSING: if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix attrs["__commands_flag_prefix__"] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False) case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':') delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault('__commands_flag_prefix__', '') prefix = attrs.setdefault("__commands_flag_prefix__", "")
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag flags[flag_name] = flag
@ -337,11 +349,11 @@ class FlagsMeta(type):
keys.extend(re.escape(a) for a in aliases) keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True) keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys) joined = "|".join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags) pattern = re.compile(f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags)
attrs['__commands_flag_regex__'] = pattern attrs["__commands_flag_regex__"] = pattern
attrs['__commands_flags__'] = flags attrs["__commands_flags__"] = flags
attrs['__commands_flag_aliases__'] = aliases attrs["__commands_flag_aliases__"] = aliases
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
@ -432,7 +444,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter') F = TypeVar("F", bound="FlagConverter")
class FlagConverter(metaclass=FlagsMeta): class FlagConverter(metaclass=FlagsMeta):
@ -493,8 +505,8 @@ class FlagConverter(metaclass=FlagsMeta):
return self return self
def __repr__(self) -> str: def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>' return f"<{self.__class__.__name__} {pairs}>"
@classmethod @classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]: def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@ -507,7 +519,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__ case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument): for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0) begin, end = match.span(0)
key = match.group('flag') key = match.group("flag")
if case_insensitive: if case_insensitive:
key = key.casefold() key = key.casefold()

View File

@ -39,10 +39,10 @@ if TYPE_CHECKING:
from .context import Context from .context import Context
__all__ = ( __all__ = (
'Paginator', "Paginator",
'HelpCommand', "HelpCommand",
'DefaultHelpCommand', "DefaultHelpCommand",
'MinimalHelpCommand', "MinimalHelpCommand",
) )
# help -> shows info of bot on top/bottom and lists subcommands # help -> shows info of bot on top/bottom and lists subcommands
@ -89,7 +89,7 @@ class Paginator:
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
self.prefix = prefix self.prefix = prefix
self.suffix = suffix self.suffix = suffix
self.max_size = max_size self.max_size = max_size
@ -118,7 +118,7 @@ class Paginator:
def _linesep_len(self): def _linesep_len(self):
return len(self.linesep) return len(self.linesep)
def add_line(self, line='', *, empty=False): def add_line(self, line="", *, empty=False):
"""Adds a line to the current page. """Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception If the line exceeds the :attr:`max_size` then an exception
@ -138,7 +138,7 @@ class Paginator:
""" """
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
if len(line) > max_page_size: if len(line) > max_page_size:
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}') raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len: if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
self.close_page() self.close_page()
@ -147,7 +147,7 @@ class Paginator:
self._current_page.append(line) self._current_page.append(line)
if empty: if empty:
self._current_page.append('') self._current_page.append("")
self._count += self._linesep_len self._count += self._linesep_len
def close_page(self): def close_page(self):
@ -176,7 +176,7 @@ class Paginator:
return self._pages return self._pages
def __repr__(self): def __repr__(self):
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>' fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
return fmt.format(self) return fmt.format(self)
@ -197,7 +197,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback self.callback = injected.command_callback
on_error = injected.on_help_command_error on_error = injected.on_help_command_error
if not hasattr(on_error, '__help_command_not_overriden__'): if not hasattr(on_error, "__help_command_not_overriden__"):
if self.cog is not None: if self.cog is not None:
self.on_error = self._on_error_cog_implementation self.on_error = self._on_error_cog_implementation
else: else:
@ -224,7 +224,7 @@ class _HelpCommandImpl(Command):
try: try:
del result[next(iter(result))] del result[next(iter(result))]
except StopIteration: except StopIteration:
raise ValueError('Missing context parameter') from None raise ValueError("Missing context parameter") from None
else: else:
return result return result
@ -296,13 +296,13 @@ class HelpCommand:
""" """
MENTION_TRANSFORMS = { MENTION_TRANSFORMS = {
'@everyone': '@\u200beveryone', "@everyone": "@\u200beveryone",
'@here': '@\u200bhere', "@here": "@\u200bhere",
r'<@!?[0-9]{17,22}>': '@deleted-user', r"<@!?[0-9]{17,22}>": "@deleted-user",
r'<@&[0-9]{17,22}>': '@deleted-role', r"<@&[0-9]{17,22}>": "@deleted-role",
} }
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing # To prevent race conditions of a single instance while also allowing
@ -321,11 +321,11 @@ class HelpCommand:
return self return self
def __init__(self, **options): def __init__(self, **options):
self.show_hidden = options.pop('show_hidden', False) self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop('verify_checks', True) self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop('command_attrs', {}) self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault('name', 'help') attrs.setdefault("name", "help")
attrs.setdefault('help', 'Shows this message') attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING self.context: Context = discord.utils.MISSING
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
@ -422,20 +422,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command: if not parent.signature or parent.invoke_without_command:
entries.append(parent.name) entries.append(parent.name)
else: else:
entries.append(parent.name + ' ' + parent.signature) entries.append(parent.name + " " + parent.signature)
parent = parent.parent parent = parent.parent
parent_sig = ' '.join(reversed(entries)) parent_sig = " ".join(reversed(entries))
if len(command.aliases) > 0: if len(command.aliases) > 0:
aliases = '|'.join(command.aliases) aliases = "|".join(command.aliases)
fmt = f'[{command.name}|{aliases}]' fmt = f"[{command.name}|{aliases}]"
if parent_sig: if parent_sig:
fmt = parent_sig + ' ' + fmt fmt = parent_sig + " " + fmt
alias = fmt alias = fmt
else: else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name alias = command.name if not parent_sig else parent_sig + " " + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}' return f"{self.context.clean_prefix}{alias} {command.signature}"
def remove_mentions(self, string): def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse. """Removes mentions from the string to prevent abuse.
@ -449,7 +449,7 @@ class HelpCommand:
""" """
def replace(obj, *, transforms=self.MENTION_TRANSFORMS): def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), '@invalid') return transforms.get(obj.group(0), "@invalid")
return self.MENTION_PATTERN.sub(replace, string) return self.MENTION_PATTERN.sub(replace, string)
@ -615,7 +615,7 @@ class HelpCommand:
:class:`.abc.Messageable` :class:`.abc.Messageable`
The destination where the help command will be output. The destination where the help command will be output.
""" """
return self.context.channel return self.context
async def send_error_message(self, error): async def send_error_message(self, error):
"""|coro| """|coro|
@ -846,7 +846,7 @@ class HelpCommand:
# Since we want to have detailed errors when someone # Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through # passes an invalid subcommand, we need to walk through
# the command group chain ourselves. # the command group chain ourselves.
keys = command.split(' ') keys = command.split(" ")
cmd = bot.all_commands.get(keys[0]) cmd = bot.all_commands.get(keys[0])
if cmd is None: if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
@ -907,14 +907,14 @@ class DefaultHelpCommand(HelpCommand):
""" """
def __init__(self, **options): def __init__(self, **options):
self.width = options.pop('width', 80) self.width = options.pop("width", 80)
self.indent = options.pop('indent', 2) self.indent = options.pop("indent", 2)
self.sort_commands = options.pop('sort_commands', True) self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop('dm_help', False) self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000) self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop('commands_heading', "Commands:") self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop('no_category', 'No Category') self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop('paginator', None) self.paginator = options.pop("paginator", None)
if self.paginator is None: if self.paginator is None:
self.paginator = Paginator() self.paginator = Paginator()
@ -924,7 +924,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text): def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`.""" """:class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width: if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...' return text[: self.width - 3].rstrip() + "..."
return text return text
def get_ending_note(self): def get_ending_note(self):
@ -977,6 +977,10 @@ class DefaultHelpCommand(HelpCommand):
for page in self.paginator.pages: for page in self.paginator.pages:
await destination.send(page) await destination.send(page)
interaction = self.context.interaction
if interaction is not None and destination == self.context.author and not interaction.response.is_done():
await interaction.response.send_message("Sent help to your DMs!", ephemeral=True)
def add_command_formatting(self, command): def add_command_formatting(self, command):
"""A utility function to format the non-indented block of commands and groups. """A utility function to format the non-indented block of commands and groups.
@ -1007,7 +1011,7 @@ class DefaultHelpCommand(HelpCommand):
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold: elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
return ctx.author return ctx.author
else: else:
return ctx.channel return ctx
async def prepare_help_command(self, ctx, command): async def prepare_help_command(self, ctx, command):
self.paginator.clear() self.paginator.clear()
@ -1021,11 +1025,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion # <description> portion
self.paginator.add_line(bot.description, empty=True) self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:' no_category = f"\u200b{self.no_category}:"
def get_category(command, *, no_category=no_category): def get_category(command, *, no_category=no_category):
cog = command.cog cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category return cog.qualified_name + ":" if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category) filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered) max_size = self.get_max_size(filtered)
@ -1110,13 +1114,13 @@ class MinimalHelpCommand(HelpCommand):
""" """
def __init__(self, **options): def __init__(self, **options):
self.sort_commands = options.pop('sort_commands', True) self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop('commands_heading', "Commands") self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop('dm_help', False) self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000) self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:") self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop('no_category', 'No Category') self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop('paginator', None) self.paginator = options.pop("paginator", None)
if self.paginator is None: if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None) self.paginator = Paginator(suffix=None, prefix=None)
@ -1149,7 +1153,7 @@ class MinimalHelpCommand(HelpCommand):
) )
def get_command_signature(self, command): def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
def get_ending_note(self): def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes. """Return the help command's ending note. This is mainly useful to override for i18n purposes.
@ -1180,8 +1184,8 @@ class MinimalHelpCommand(HelpCommand):
""" """
if commands: if commands:
# U+2002 Middle Dot # U+2002 Middle Dot
joined = '\u2002'.join(c.name for c in commands) joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__') self.paginator.add_line(f"__**{heading}**__")
self.paginator.add_line(joined) self.paginator.add_line(joined)
def add_subcommand_formatting(self, command): def add_subcommand_formatting(self, command):
@ -1197,7 +1201,7 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command` command: :class:`Command`
The command to show information of. The command to show information of.
""" """
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases): def add_aliases_formatting(self, aliases):
@ -1268,7 +1272,7 @@ class MinimalHelpCommand(HelpCommand):
if note: if note:
self.paginator.add_line(note, empty=True) self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}' no_category = f"\u200b{self.no_category}"
def get_category(command, *, no_category=no_category): def get_category(command, *, no_category=no_category):
cog = command.cog cog = command.cog
@ -1302,7 +1306,7 @@ class MinimalHelpCommand(HelpCommand):
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
if filtered: if filtered:
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**') self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
for command in filtered: for command in filtered:
self.add_subcommand_formatting(command) self.add_subcommand_formatting(command)
@ -1322,7 +1326,7 @@ class MinimalHelpCommand(HelpCommand):
if note: if note:
self.paginator.add_line(note, empty=True) self.paginator.add_line(note, empty=True)
self.paginator.add_line(f'**{self.commands_heading}**') self.paginator.add_line(f"**{self.commands_heading}**")
for command in filtered: for command in filtered:
self.add_subcommand_formatting(command) self.add_subcommand_formatting(command)

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes # map from opening quotes to closing quotes
_quotes = { supported_quotes = {
'"': '"', '"': '"',
"": "", "": "",
"": "", "": "",
@ -44,7 +44,8 @@ _quotes = {
"": "", "": "",
"": "", "": "",
} }
_all_quotes = set(_quotes.keys()) | set(_quotes.values()) _all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values())
class StringView: class StringView:
def __init__(self, buffer): def __init__(self, buffer):
@ -81,20 +82,20 @@ class StringView:
def skip_string(self, string): def skip_string(self, string):
strlen = len(string) strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string: if self.buffer[self.index : self.index + strlen] == string:
self.previous = self.index self.previous = self.index
self.index += strlen self.index += strlen
return True return True
return False return False
def read_rest(self): def read_rest(self):
result = self.buffer[self.index:] result = self.buffer[self.index :]
self.previous = self.index self.previous = self.index
self.index = self.end self.index = self.end
return result return result
def read(self, n): def read(self, n):
result = self.buffer[self.index:self.index + n] result = self.buffer[self.index : self.index + n]
self.previous = self.index self.previous = self.index
self.index += n self.index += n
return result return result
@ -120,7 +121,7 @@ class StringView:
except IndexError: except IndexError:
break break
self.previous = self.index self.previous = self.index
result = self.buffer[self.index:self.index + pos] result = self.buffer[self.index : self.index + pos]
self.index += pos self.index += pos
return result return result
@ -129,7 +130,7 @@ class StringView:
if current is None: if current is None:
return None return None
close_quote = _quotes.get(current) close_quote = supported_quotes.get(current)
is_quoted = bool(close_quote) is_quoted = bool(close_quote)
if is_quoted: if is_quoted:
result = [] result = []
@ -144,11 +145,11 @@ class StringView:
if is_quoted: if is_quoted:
# unexpected EOF # unexpected EOF
raise ExpectedClosingQuoteError(close_quote) raise ExpectedClosingQuoteError(close_quote)
return ''.join(result) return "".join(result)
# currently we accept strings in the format of "hello world" # currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\"" # to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\': if current == "\\":
next_char = self.get() next_char = self.get()
if not next_char: if not next_char:
# string ends with \ and no character after it # string ends with \ and no character after it
@ -156,7 +157,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote # if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote) raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through # if we aren't then we just let it through
return ''.join(result) return "".join(result)
if next_char in _escaped_quotes: if next_char in _escaped_quotes:
# escaped quote # escaped quote
@ -179,14 +180,13 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char) raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay # we're quoted so it's okay
return ''.join(result) return "".join(result)
if current.isspace() and not is_quoted: if current.isspace() and not is_quoted:
# end of word found # end of word found
return ''.join(result) return "".join(result)
result.append(current) result.append(current)
def __repr__(self): def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>' return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"

View File

@ -48,19 +48,17 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
__all__ = ( __all__ = ("loop",)
'loop',
)
T = TypeVar('T') T = TypeVar("T")
_func = Callable[..., Awaitable[Any]] _func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func) LF = TypeVar("LF", bound=_func)
FT = TypeVar('FT', bound=_func) FT = TypeVar("FT", bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle: class SleepHandle:
__slots__ = ('future', 'loop', 'handle') __slots__ = ("future", "loop", "handle")
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop self.loop = loop
@ -124,7 +122,7 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False self._stop_next_iteration = False
if self.count is not None and self.count <= 0: if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.') raise ValueError("count must be greater than 0 or None.")
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False self._last_iteration_failed = False
@ -132,10 +130,10 @@ class Loop(Generic[LF]):
self._next_iteration = None self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro): if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, '_' + name) coro = getattr(self, "_" + name)
if coro is None: if coro is None:
return return
@ -150,7 +148,7 @@ class Loop(Generic[LF]):
async def _loop(self, *args: Any, **kwargs: Any) -> None: async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
await self._call_loop_function('before_loop') await self._call_loop_function("before_loop")
self._last_iteration_failed = False self._last_iteration_failed = False
if self._time is not MISSING: if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started # the time index should be prepared every time the internal loop is started
@ -193,10 +191,10 @@ class Loop(Generic[LF]):
raise raise
except Exception as exc: except Exception as exc:
self._has_failed = True self._has_failed = True
await self._call_loop_function('error', exc) await self._call_loop_function("error", exc)
raise exc raise exc
finally: finally:
await self._call_loop_function('after_loop') await self._call_loop_function("after_loop")
self._handle.cancel() self._handle.cancel()
self._is_being_cancelled = False self._is_being_cancelled = False
self._current_loop = 0 self._current_loop = 0
@ -323,7 +321,7 @@ class Loop(Generic[LF]):
""" """
if self._task is not MISSING and not self._task.done(): if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.') raise RuntimeError("Task is already launched and is not completed.")
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
@ -410,9 +408,9 @@ class Loop(Generic[LF]):
for exc in exceptions: for exc in exceptions:
if not inspect.isclass(exc): if not inspect.isclass(exc):
raise TypeError(f'{exc!r} must be a class.') raise TypeError(f"{exc!r} must be a class.")
if not issubclass(exc, BaseException): if not issubclass(exc, BaseException):
raise TypeError(f'{exc!r} must inherit from BaseException.') raise TypeError(f"{exc!r} must inherit from BaseException.")
self._valid_exception = (*self._valid_exception, *exceptions) self._valid_exception = (*self._valid_exception, *exceptions)
@ -466,7 +464,7 @@ class Loop(Generic[LF]):
async def _error(self, *args: Any) -> None: async def _error(self, *args: Any) -> None:
exception: Exception = args[-1] exception: Exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro: FT) -> FT: def before_loop(self, coro: FT) -> FT:
@ -489,7 +487,7 @@ class Loop(Generic[LF]):
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._before_loop = coro self._before_loop = coro
return coro return coro
@ -517,7 +515,7 @@ class Loop(Generic[LF]):
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._after_loop = coro self._after_loop = coro
return coro return coro
@ -543,7 +541,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
self._error = coro # type: ignore self._error = coro # type: ignore
return coro return coro
@ -601,16 +599,16 @@ class Loop(Generic[LF]):
return [inner] return [inner]
if not isinstance(time, Sequence): if not isinstance(time, Sequence):
raise TypeError( raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.' f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
) )
if not time: if not time:
raise ValueError('time parameter must not be an empty sequence.') raise ValueError("time parameter must not be an empty sequence.")
ret: List[datetime.time] = [] ret: List[datetime.time] = []
for index, t in enumerate(time): for index, t in enumerate(time):
if not isinstance(t, dt): if not isinstance(t, dt):
raise TypeError( raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.' f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
) )
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
@ -663,7 +661,7 @@ class Loop(Generic[LF]):
hours = hours or 0 hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0) sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0: if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.') raise ValueError("Total number of seconds cannot be less than zero.")
self._sleep = sleep self._sleep = sleep
self._seconds = float(seconds) self._seconds = float(seconds)
@ -672,7 +670,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING self._time: List[datetime.time] = MISSING
else: else:
if any((seconds, minutes, hours)): if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time') raise TypeError("Cannot mix explicit time with relative time")
self._time = self._get_time_parameter(time) self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING self._sleep = self._seconds = self._minutes = self._hours = MISSING

View File

@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union
import os import os
import io import io
__all__ = ( __all__ = ("File",)
'File',
)
class File: class File:
@ -64,7 +62,7 @@ class File:
Whether the attachment is a spoiler. Whether the attachment is a spoiler.
""" """
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer') __slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer")
if TYPE_CHECKING: if TYPE_CHECKING:
fp: io.BufferedIOBase fp: io.BufferedIOBase
@ -80,12 +78,12 @@ class File:
): ):
if isinstance(fp, io.IOBase): if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()): if not (fp.seekable() and fp.readable()):
raise ValueError(f'File buffer {fp!r} must be seekable and readable') raise ValueError(f"File buffer {fp!r} must be seekable and readable")
self.fp = fp self.fp = fp
self._original_pos = fp.tell() self._original_pos = fp.tell()
self._owner = False self._owner = False
else: else:
self.fp = open(fp, 'rb') self.fp = open(fp, "rb")
self._original_pos = 0 self._original_pos = 0
self._owner = True self._owner = True
@ -100,14 +98,14 @@ class File:
if isinstance(fp, str): if isinstance(fp, str):
_, self.filename = os.path.split(fp) _, self.filename = os.path.split(fp)
else: else:
self.filename = getattr(fp, 'name', None) self.filename = getattr(fp, "name", None)
else: else:
self.filename = filename self.filename = filename
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'): if spoiler and self.filename is not None and not self.filename.startswith("SPOILER_"):
self.filename = 'SPOILER_' + self.filename self.filename = "SPOILER_" + self.filename
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_')) self.spoiler = spoiler or (self.filename is not None and self.filename.startswith("SPOILER_"))
def reset(self, *, seek: Union[int, bool] = True) -> None: def reset(self, *, seek: Union[int, bool] = True) -> None:
# The `seek` parameter is needed because # The `seek` parameter is needed because

View File

@ -29,16 +29,16 @@ from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optio
from .enums import UserFlags from .enums import UserFlags
__all__ = ( __all__ = (
'SystemChannelFlags', "SystemChannelFlags",
'MessageFlags', "MessageFlags",
'PublicUserFlags', "PublicUserFlags",
'Intents', "Intents",
'MemberCacheFlags', "MemberCacheFlags",
'ApplicationFlags', "ApplicationFlags",
) )
FV = TypeVar('FV', bound='flag_value') FV = TypeVar("FV", bound="flag_value")
BF = TypeVar('BF', bound='BaseFlags') BF = TypeVar("BF", bound="BaseFlags")
class flag_value: class flag_value:
@ -63,7 +63,7 @@ class flag_value:
instance._set_flag(self.flag, value) instance._set_flag(self.flag, value)
def __repr__(self): def __repr__(self):
return f'<flag_value flag={self.flag!r}>' return f"<flag_value flag={self.flag!r}>"
class alias_flag_value(flag_value): class alias_flag_value(flag_value):
@ -98,13 +98,13 @@ class BaseFlags:
value: int value: int
__slots__ = ('value',) __slots__ = ("value",)
def __init__(self, **kwargs: bool): def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
@ -123,7 +123,7 @@ class BaseFlags:
return hash(self.value) return hash(self.value)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} value={self.value}>' return f"<{self.__class__.__name__} value={self.value}>"
def __iter__(self) -> Iterator[Tuple[str, bool]]: def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items(): for name, value in self.__class__.__dict__.items():
@ -142,7 +142,7 @@ class BaseFlags:
elif toggle is False: elif toggle is False:
self.value &= ~o self.value &= ~o
else: else:
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.') raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.")
@fill_with_flags(inverted=True) @fill_with_flags(inverted=True)
@ -196,7 +196,7 @@ class SystemChannelFlags(BaseFlags):
elif toggle is False: elif toggle is False:
self.value |= o self.value |= o
else: else:
raise TypeError('Value to set for SystemChannelFlags must be a bool.') raise TypeError("Value to set for SystemChannelFlags must be a bool.")
@flag_value @flag_value
def join_notifications(self): def join_notifications(self):
@ -461,7 +461,7 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
@ -480,16 +480,6 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE self.value = self.DEFAULT_VALUE
return self return self
@classmethod
def default(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled
except :attr:`presences` and :attr:`members`.
"""
self = cls.all()
self.presences = False
self.members = False
return self
@flag_value @flag_value
def guilds(self): def guilds(self):
""":class:`bool`: Whether guild related events are enabled. """:class:`bool`: Whether guild related events are enabled.
@ -917,7 +907,7 @@ class MemberCacheFlags(BaseFlags):
self.value = (1 << bits) - 1 self.value = (1 << bits) - 1
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f"{key!r} is not a valid flag name.")
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
@ -987,10 +977,10 @@ class MemberCacheFlags(BaseFlags):
def _verify_intents(self, intents: Intents): def _verify_intents(self, intents: Intents):
if self.voice and not intents.voice_states: if self.voice and not intents.voice_states:
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states') raise ValueError("MemberCacheFlags.voice requires Intents.voice_states")
if self.joined and not intents.members: if self.joined and not intents.members:
raise ValueError('MemberCacheFlags.joined requires Intents.members') raise ValueError("MemberCacheFlags.joined requires Intents.members")
@property @property
def _voice_only(self): def _voice_only(self):

File diff suppressed because it is too large Load Diff

View File

@ -46,7 +46,7 @@ from . import utils, abc
from .role import Role from .role import Role
from .member import Member, VoiceState from .member import Member, VoiceState
from .emoji import Emoji from .emoji import Emoji
from .errors import InvalidData from .errors import InvalidData, NotFound
from .permissions import PermissionOverwrite from .permissions import PermissionOverwrite
from .colour import Colour from .colour import Colour
from .errors import InvalidArgument, ClientException from .errors import InvalidArgument, ClientException
@ -78,9 +78,7 @@ from .sticker import GuildSticker
from .file import File from .file import File
__all__ = ( __all__ = ("Guild",)
'Guild',
)
MISSING = utils.MISSING MISSING = utils.MISSING
@ -140,6 +138,10 @@ class Guild(Hashable):
Returns the guild's name. Returns the guild's name.
.. describe:: int(x)
Returns the guild's ID.
Attributes Attributes
---------- ----------
name: :class:`str` name: :class:`str`
@ -235,45 +237,45 @@ class Guild(Hashable):
""" """
__slots__ = ( __slots__ = (
'afk_timeout', "afk_timeout",
'afk_channel', "afk_channel",
'name', "name",
'id', "id",
'unavailable', "unavailable",
'region', "region",
'owner_id', "owner_id",
'mfa_level', "mfa_level",
'emojis', "emojis",
'stickers', "stickers",
'features', "features",
'verification_level', "verification_level",
'explicit_content_filter', "explicit_content_filter",
'default_notifications', "default_notifications",
'description', "description",
'max_presences', "max_presences",
'max_members', "max_members",
'max_video_channel_users', "max_video_channel_users",
'premium_tier', "premium_tier",
'premium_subscription_count', "premium_subscription_count",
'preferred_locale', "preferred_locale",
'nsfw_level', "nsfw_level",
'_members', "_members",
'_channels', "_channels",
'_icon', "_icon",
'_banner', "_banner",
'_state', "_state",
'_roles', "_roles",
'_member_count', "_member_count",
'_large', "_large",
'_splash', "_splash",
'_voice_states', "_voice_states",
'_system_channel_id', "_system_channel_id",
'_system_channel_flags', "_system_channel_flags",
'_discovery_splash', "_discovery_splash",
'_rules_channel_id', "_rules_channel_id",
'_public_updates_channel_id', "_public_updates_channel_id",
'_stage_instances', "_stage_instances",
'_threads', "_threads",
) )
_PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = {
@ -333,21 +335,23 @@ class Guild(Hashable):
return to_remove return to_remove
def __str__(self) -> str: def __str__(self) -> str:
return self.name or '' return self.name or ""
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ( attrs = (
('id', self.id), ("id", self.id),
('name', self.name), ("name", self.name),
('shard_id', self.shard_id), ("shard_id", self.shard_id),
('chunked', self.chunked), ("chunked", self.chunked),
('member_count', getattr(self, '_member_count', None)), ("member_count", getattr(self, "_member_count", None)),
) )
inner = ' '.join('%s=%r' % t for t in attrs) inner = " ".join("%s=%r" % t for t in attrs)
return f'<Guild {inner}>' return f"<Guild {inner}>"
def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: def _update_voice_state(
user_id = int(data['user_id']) self, data: GuildVoiceState, channel_id: int
) -> Tuple[Optional[Member], VoiceState, VoiceState]:
user_id = int(data["user_id"])
channel = self.get_channel(channel_id) channel = self.get_channel(channel_id)
try: try:
# check if we should remove the voice state from cache # check if we should remove the voice state from cache
@ -367,7 +371,7 @@ class Guild(Hashable):
member = self.get_member(user_id) member = self.get_member(user_id)
if member is None: if member is None:
try: try:
member = Member(data=data['member'], state=self._state, guild=self) member = Member(data=data["member"], state=self._state, guild=self)
except KeyError: except KeyError:
member = None member = None
@ -399,57 +403,57 @@ class Guild(Hashable):
def _from_data(self, guild: GuildPayload) -> None: def _from_data(self, guild: GuildPayload) -> None:
# according to Stan, this is always available even if the guild is unavailable # according to Stan, this is always available even if the guild is unavailable
# I don't have this guarantee when someone updates the guild. # I don't have this guarantee when someone updates the guild.
member_count = guild.get('member_count', None) member_count = guild.get("member_count", None)
if member_count is not None: if member_count is not None:
self._member_count: int = member_count self._member_count: int = member_count
self.name: str = guild.get('name') self.name: str = guild.get("name")
self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region')) self.region: VoiceRegion = try_enum(VoiceRegion, guild.get("region"))
self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level')) self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get("verification_level"))
self.default_notifications: NotificationLevel = try_enum( self.default_notifications: NotificationLevel = try_enum(
NotificationLevel, guild.get('default_message_notifications') NotificationLevel, guild.get("default_message_notifications")
) )
self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0)) self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get("explicit_content_filter", 0))
self.afk_timeout: int = guild.get('afk_timeout') self.afk_timeout: int = guild.get("afk_timeout")
self._icon: Optional[str] = guild.get('icon') self._icon: Optional[str] = guild.get("icon")
self._banner: Optional[str] = guild.get('banner') self._banner: Optional[str] = guild.get("banner")
self.unavailable: bool = guild.get('unavailable', False) self.unavailable: bool = guild.get("unavailable", False)
self.id: int = int(guild['id']) self.id: int = int(guild["id"])
self._roles: Dict[int, Role] = {} self._roles: Dict[int, Role] = {}
state = self._state # speed up attribute access state = self._state # speed up attribute access
for r in guild.get('roles', []): for r in guild.get("roles", []):
role = Role(guild=self, data=r, state=state) role = Role(guild=self, data=r, state=state)
self._roles[role.id] = role self._roles[role.id] = role
self.mfa_level: MFALevel = guild.get('mfa_level') self.mfa_level: MFALevel = guild.get("mfa_level")
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', []))) self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", [])))
self.stickers: Tuple[GuildSticker, ...] = tuple( self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get('stickers', [])) map(lambda d: state.store_sticker(self, d), guild.get("stickers", []))
) )
self.features: List[GuildFeature] = guild.get('features', []) self.features: List[GuildFeature] = guild.get("features", [])
self._splash: Optional[str] = guild.get('splash') self._splash: Optional[str] = guild.get("splash")
self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id') self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id")
self.description: Optional[str] = guild.get('description') self.description: Optional[str] = guild.get("description")
self.max_presences: Optional[int] = guild.get('max_presences') self.max_presences: Optional[int] = guild.get("max_presences")
self.max_members: Optional[int] = guild.get('max_members') self.max_members: Optional[int] = guild.get("max_members")
self.max_video_channel_users: Optional[int] = guild.get('max_video_channel_users') self.max_video_channel_users: Optional[int] = guild.get("max_video_channel_users")
self.premium_tier: int = guild.get('premium_tier', 0) self.premium_tier: int = guild.get("premium_tier", 0)
self.premium_subscription_count: int = guild.get('premium_subscription_count') or 0 self.premium_subscription_count: int = guild.get("premium_subscription_count") or 0
self._system_channel_flags: int = guild.get('system_channel_flags', 0) self._system_channel_flags: int = guild.get("system_channel_flags", 0)
self.preferred_locale: Optional[str] = guild.get('preferred_locale') self.preferred_locale: Optional[str] = guild.get("preferred_locale")
self._discovery_splash: Optional[str] = guild.get('discovery_splash') self._discovery_splash: Optional[str] = guild.get("discovery_splash")
self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'rules_channel_id') self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, "rules_channel_id")
self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'public_updates_channel_id') self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, "public_updates_channel_id")
self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get('nsfw_level', 0)) self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0))
self._stage_instances: Dict[int, StageInstance] = {} self._stage_instances: Dict[int, StageInstance] = {}
for s in guild.get('stage_instances', []): for s in guild.get("stage_instances", []):
stage_instance = StageInstance(guild=self, data=s, state=state) stage_instance = StageInstance(guild=self, data=s, state=state)
self._stage_instances[stage_instance.id] = stage_instance self._stage_instances[stage_instance.id] = stage_instance
cache_joined = self._state.member_cache_flags.joined cache_joined = self._state.member_cache_flags.joined
self_id = self._state.self_id self_id = self._state.self_id
for mdata in guild.get('members', []): for mdata in guild.get("members", []):
member = Member(data=mdata, guild=self, state=state) member = Member(data=mdata, guild=self, state=state)
if cache_joined or member.id == self_id: if cache_joined or member.id == self_id:
self._add_member(member) self._add_member(member)
@ -457,35 +461,35 @@ class Guild(Hashable):
self._sync(guild) self._sync(guild)
self._large: Optional[bool] = None if member_count is None else self._member_count >= 250 self._large: Optional[bool] = None if member_count is None else self._member_count >= 250
self.owner_id: Optional[int] = utils._get_as_snowflake(guild, 'owner_id') self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id")
self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) # type: ignore self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore
for obj in guild.get('voice_states', []): for obj in guild.get("voice_states", []):
self._update_voice_state(obj, int(obj['channel_id'])) self._update_voice_state(obj, int(obj["channel_id"]))
# TODO: refactor/remove? # TODO: refactor/remove?
def _sync(self, data: GuildPayload) -> None: def _sync(self, data: GuildPayload) -> None:
try: try:
self._large = data['large'] self._large = data["large"]
except KeyError: except KeyError:
pass pass
empty_tuple = tuple() empty_tuple = tuple()
for presence in data.get('presences', []): for presence in data.get("presences", []):
user_id = int(presence['user']['id']) user_id = int(presence["user"]["id"])
member = self.get_member(user_id) member = self.get_member(user_id)
if member is not None: if member is not None:
member._presence_update(presence, empty_tuple) # type: ignore member._presence_update(presence, empty_tuple) # type: ignore
if 'channels' in data: if "channels" in data:
channels = data['channels'] channels = data["channels"]
for c in channels: for c in channels:
factory, ch_type = _guild_channel_factory(c['type']) factory, ch_type = _guild_channel_factory(c["type"])
if factory: if factory:
self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore
if 'threads' in data: if "threads" in data:
threads = data['threads'] threads = data["threads"]
for thread in threads: for thread in threads:
self._add_thread(Thread(guild=self, state=self._state, data=thread)) self._add_thread(Thread(guild=self, state=self._state, data=thread))
@ -708,7 +712,7 @@ class Guild(Hashable):
@property @property
def emoji_limit(self) -> int: def emoji_limit(self) -> int:
""":class:`int`: The maximum number of emoji slots this guild has.""" """:class:`int`: The maximum number of emoji slots this guild has."""
more_emoji = 200 if 'MORE_EMOJI' in self.features else 50 more_emoji = 200 if "MORE_EMOJI" in self.features else 50
return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji) return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji)
@property @property
@ -717,13 +721,13 @@ class Guild(Hashable):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
more_stickers = 60 if 'MORE_STICKERS' in self.features else 0 more_stickers = 60 if "MORE_STICKERS" in self.features else 0
return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers) return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers)
@property @property
def bitrate_limit(self) -> float: def bitrate_limit(self) -> float:
""":class:`float`: The maximum bitrate for voice channels this guild can have.""" """:class:`float`: The maximum bitrate for voice channels this guild can have."""
vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if 'VIP_REGIONS' in self.features else 96e3 vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if "VIP_REGIONS" in self.features else 96e3
return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate) return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate)
@property @property
@ -738,12 +742,16 @@ class Guild(Hashable):
@property @property
def humans(self) -> List[Member]: def humans(self) -> List[Member]:
"""List[:class:`Member`]: A list of human members that belong to this guild.""" """List[:class:`Member`]: A list of human members that belong to this guild.
.. versionadded:: 2.0"""
return [member for member in self.members if not member.bot] return [member for member in self.members if not member.bot]
@property @property
def bots(self) -> List[Member]: def bots(self) -> List[Member]:
"""List[:class:`Member`]: A list of bots that belong to this guild.""" """List[:class:`Member`]: A list of bots that belong to this guild.
.. versionadded:: 2.0"""
return [member for member in self.members if member.bot] return [member for member in self.members if member.bot]
def get_member(self, user_id: int, /) -> Optional[Member]: def get_member(self, user_id: int, /) -> Optional[Member]:
@ -863,21 +871,21 @@ class Guild(Hashable):
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None: if self._banner is None:
return None return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') return Asset._from_guild_image(self._state, self.id, self._banner, path="banners")
@property @property
def splash(self) -> Optional[Asset]: def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None: if self._splash is None:
return None return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes")
@property @property
def discovery_splash(self) -> Optional[Asset]: def discovery_splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available."""
if self._discovery_splash is None: if self._discovery_splash is None:
return None return None
return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path="discovery-splashes")
@property @property
def member_count(self) -> int: def member_count(self) -> int:
@ -901,7 +909,7 @@ class Guild(Hashable):
If this value returns ``False``, then you should request for If this value returns ``False``, then you should request for
offline members. offline members.
""" """
count = getattr(self, '_member_count', None) count = getattr(self, "_member_count", None)
if count is None: if count is None:
return False return False
return count == len(self._members) return count == len(self._members)
@ -948,7 +956,7 @@ class Guild(Hashable):
result = None result = None
members = self.members members = self.members
if len(name) > 5 and name[-5] == '#': if len(name) > 5 and name[-5] == "#":
# The 5 length is checking to see if #0000 is in the string, # The 5 length is checking to see if #0000 is in the string,
# as a#0000 has a length of 6, the minimum for a potential # as a#0000 has a length of 6, the minimum for a potential
# discriminator lookup. # discriminator lookup.
@ -976,20 +984,20 @@ class Guild(Hashable):
if overwrites is MISSING: if overwrites is MISSING:
overwrites = {} overwrites = {}
elif not isinstance(overwrites, dict): elif not isinstance(overwrites, dict):
raise InvalidArgument('overwrites parameter expects a dict.') raise InvalidArgument("overwrites parameter expects a dict.")
perms = [] perms = []
for target, perm in overwrites.items(): for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite): if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")
allow, deny = perm.pair() allow, deny = perm.pair()
payload = {'allow': allow.value, 'deny': deny.value, 'id': target.id} payload = {"allow": allow.value, "deny": deny.value, "id": target.id}
if isinstance(target, Role): if isinstance(target, Role):
payload['type'] = abc._Overwrites.ROLE payload["type"] = abc._Overwrites.ROLE
else: else:
payload['type'] = abc._Overwrites.MEMBER payload["type"] = abc._Overwrites.MEMBER
perms.append(payload) perms.append(payload)
@ -1090,16 +1098,16 @@ class Guild(Hashable):
options = {} options = {}
if position is not MISSING: if position is not MISSING:
options['position'] = position options["position"] = position
if topic is not MISSING: if topic is not MISSING:
options['topic'] = topic options["topic"] = topic
if slowmode_delay is not MISSING: if slowmode_delay is not MISSING:
options['rate_limit_per_user'] = slowmode_delay options["rate_limit_per_user"] = slowmode_delay
if nsfw is not MISSING: if nsfw is not MISSING:
options['nsfw'] = nsfw options["nsfw"] = nsfw
data = await self._create_channel( data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options
@ -1174,19 +1182,19 @@ class Guild(Hashable):
""" """
options = {} options = {}
if position is not MISSING: if position is not MISSING:
options['position'] = position options["position"] = position
if bitrate is not MISSING: if bitrate is not MISSING:
options['bitrate'] = bitrate options["bitrate"] = bitrate
if user_limit is not MISSING: if user_limit is not MISSING:
options['user_limit'] = user_limit options["user_limit"] = user_limit
if rtc_region is not MISSING: if rtc_region is not MISSING:
options['rtc_region'] = None if rtc_region is None else str(rtc_region) options["rtc_region"] = None if rtc_region is None else str(rtc_region)
if video_quality_mode is not MISSING: if video_quality_mode is not MISSING:
options['video_quality_mode'] = video_quality_mode.value options["video_quality_mode"] = video_quality_mode.value
data = await self._create_channel( data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options
@ -1249,13 +1257,18 @@ class Guild(Hashable):
""" """
options: Dict[str, Any] = { options: Dict[str, Any] = {
'topic': topic, "topic": topic,
} }
if position is not MISSING: if position is not MISSING:
options['position'] = position options["position"] = position
data = await self._create_channel( data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.stage_voice, category=category, reason=reason, **options name,
overwrites=overwrites,
channel_type=ChannelType.stage_voice,
category=category,
reason=reason,
**options,
) )
channel = StageChannel(state=self._state, guild=self, data=data) channel = StageChannel(state=self._state, guild=self, data=data)
@ -1296,7 +1309,7 @@ class Guild(Hashable):
""" """
options: Dict[str, Any] = {} options: Dict[str, Any] = {}
if position is not MISSING: if position is not MISSING:
options['position'] = position options["position"] = position
data = await self._create_channel( data = await self._create_channel(
name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options
@ -1471,108 +1484,108 @@ class Guild(Hashable):
fields: Dict[str, Any] = {} fields: Dict[str, Any] = {}
if name is not MISSING: if name is not MISSING:
fields['name'] = name fields["name"] = name
if description is not MISSING: if description is not MISSING:
fields['description'] = description fields["description"] = description
if preferred_locale is not MISSING: if preferred_locale is not MISSING:
fields['preferred_locale'] = preferred_locale fields["preferred_locale"] = preferred_locale
if afk_timeout is not MISSING: if afk_timeout is not MISSING:
fields['afk_timeout'] = afk_timeout fields["afk_timeout"] = afk_timeout
if icon is not MISSING: if icon is not MISSING:
if icon is None: if icon is None:
fields['icon'] = icon fields["icon"] = icon
else: else:
fields['icon'] = utils._bytes_to_base64_data(icon) fields["icon"] = utils._bytes_to_base64_data(icon)
if banner is not MISSING: if banner is not MISSING:
if banner is None: if banner is None:
fields['banner'] = banner fields["banner"] = banner
else: else:
fields['banner'] = utils._bytes_to_base64_data(banner) fields["banner"] = utils._bytes_to_base64_data(banner)
if splash is not MISSING: if splash is not MISSING:
if splash is None: if splash is None:
fields['splash'] = splash fields["splash"] = splash
else: else:
fields['splash'] = utils._bytes_to_base64_data(splash) fields["splash"] = utils._bytes_to_base64_data(splash)
if discovery_splash is not MISSING: if discovery_splash is not MISSING:
if discovery_splash is None: if discovery_splash is None:
fields['discovery_splash'] = discovery_splash fields["discovery_splash"] = discovery_splash
else: else:
fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash) fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash)
if default_notifications is not MISSING: if default_notifications is not MISSING:
if not isinstance(default_notifications, NotificationLevel): if not isinstance(default_notifications, NotificationLevel):
raise InvalidArgument('default_notifications field must be of type NotificationLevel') raise InvalidArgument("default_notifications field must be of type NotificationLevel")
fields['default_message_notifications'] = default_notifications.value fields["default_message_notifications"] = default_notifications.value
if afk_channel is not MISSING: if afk_channel is not MISSING:
if afk_channel is None: if afk_channel is None:
fields['afk_channel_id'] = afk_channel fields["afk_channel_id"] = afk_channel
else: else:
fields['afk_channel_id'] = afk_channel.id fields["afk_channel_id"] = afk_channel.id
if system_channel is not MISSING: if system_channel is not MISSING:
if system_channel is None: if system_channel is None:
fields['system_channel_id'] = system_channel fields["system_channel_id"] = system_channel
else: else:
fields['system_channel_id'] = system_channel.id fields["system_channel_id"] = system_channel.id
if rules_channel is not MISSING: if rules_channel is not MISSING:
if rules_channel is None: if rules_channel is None:
fields['rules_channel_id'] = rules_channel fields["rules_channel_id"] = rules_channel
else: else:
fields['rules_channel_id'] = rules_channel.id fields["rules_channel_id"] = rules_channel.id
if public_updates_channel is not MISSING: if public_updates_channel is not MISSING:
if public_updates_channel is None: if public_updates_channel is None:
fields['public_updates_channel_id'] = public_updates_channel fields["public_updates_channel_id"] = public_updates_channel
else: else:
fields['public_updates_channel_id'] = public_updates_channel.id fields["public_updates_channel_id"] = public_updates_channel.id
if owner is not MISSING: if owner is not MISSING:
if self.owner_id != self._state.self_id: if self.owner_id != self._state.self_id:
raise InvalidArgument('To transfer ownership you must be the owner of the guild.') raise InvalidArgument("To transfer ownership you must be the owner of the guild.")
fields['owner_id'] = owner.id fields["owner_id"] = owner.id
if region is not MISSING: if region is not MISSING:
fields['region'] = str(region) fields["region"] = str(region)
if verification_level is not MISSING: if verification_level is not MISSING:
if not isinstance(verification_level, VerificationLevel): if not isinstance(verification_level, VerificationLevel):
raise InvalidArgument('verification_level field must be of type VerificationLevel') raise InvalidArgument("verification_level field must be of type VerificationLevel")
fields['verification_level'] = verification_level.value fields["verification_level"] = verification_level.value
if explicit_content_filter is not MISSING: if explicit_content_filter is not MISSING:
if not isinstance(explicit_content_filter, ContentFilter): if not isinstance(explicit_content_filter, ContentFilter):
raise InvalidArgument('explicit_content_filter field must be of type ContentFilter') raise InvalidArgument("explicit_content_filter field must be of type ContentFilter")
fields['explicit_content_filter'] = explicit_content_filter.value fields["explicit_content_filter"] = explicit_content_filter.value
if system_channel_flags is not MISSING: if system_channel_flags is not MISSING:
if not isinstance(system_channel_flags, SystemChannelFlags): if not isinstance(system_channel_flags, SystemChannelFlags):
raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags') raise InvalidArgument("system_channel_flags field must be of type SystemChannelFlags")
fields['system_channel_flags'] = system_channel_flags.value fields["system_channel_flags"] = system_channel_flags.value
if community is not MISSING: if community is not MISSING:
features = [] features = []
if community: if community:
if 'rules_channel_id' in fields and 'public_updates_channel_id' in fields: if "rules_channel_id" in fields and "public_updates_channel_id" in fields:
features.append('COMMUNITY') features.append("COMMUNITY")
else: else:
raise InvalidArgument( raise InvalidArgument(
'community field requires both rules_channel and public_updates_channel fields to be provided' "community field requires both rules_channel and public_updates_channel fields to be provided"
) )
fields['features'] = features fields["features"] = features
data = await http.edit_guild(self.id, reason=reason, **fields) data = await http.edit_guild(self.id, reason=reason, **fields)
return Guild(data=data, state=self._state) return Guild(data=data, state=self._state)
@ -1603,9 +1616,9 @@ class Guild(Hashable):
data = await self._state.http.get_all_guild_channels(self.id) data = await self._state.http.get_all_guild_channels(self.id)
def convert(d): def convert(d):
factory, ch_type = _guild_channel_factory(d['type']) factory, ch_type = _guild_channel_factory(d["type"])
if factory is None: if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(d)) raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d))
channel = factory(guild=self, state=self._state, data=d) channel = factory(guild=self, state=self._state, data=d)
return channel return channel
@ -1632,10 +1645,10 @@ class Guild(Hashable):
The active threads The active threads
""" """
data = await self._state.http.get_active_threads(self.id) data = await self._state.http.get_active_threads(self.id)
threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])] threads = [Thread(guild=self, state=self._state, data=d) for d in data.get("threads", [])]
thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads} thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads}
for member in data.get('members', []): for member in data.get("members", []):
thread = thread_lookup.get(int(member['id'])) thread = thread_lookup.get(int(member["id"]))
if thread is not None: if thread is not None:
thread._add_member(ThreadMember(parent=thread, data=member)) thread._add_member(ThreadMember(parent=thread, data=member))
@ -1691,7 +1704,7 @@ class Guild(Hashable):
""" """
if not self._state._intents.members: if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.') raise ClientException("Intents.members must be enabled to use this.")
return MemberIterator(self, limit=limit, after=after) return MemberIterator(self, limit=limit, after=after)
@ -1715,6 +1728,8 @@ class Guild(Hashable):
You do not have access to the guild. You do not have access to the guild.
HTTPException HTTPException
Fetching the member failed. Fetching the member failed.
NotFound
A member with that ID does not exist.
Returns Returns
-------- --------
@ -1724,6 +1739,34 @@ class Guild(Hashable):
data = await self._state.http.get_member(self.id, member_id) data = await self._state.http.get_member(self.id, member_id)
return Member(data=data, state=self._state, guild=self) return Member(data=data, state=self._state, guild=self)
async def try_member(self, member_id: int, /) -> Optional[Member]:
"""|coro|
Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`.
.. note::
This method might result in an API call.
Parameters
-----------
member_id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
member = self.get_member(member_id)
if member:
return member
else:
try:
return await self.fetch_member(member_id)
except NotFound:
return None
async def fetch_ban(self, user: Snowflake) -> BanEntry: async def fetch_ban(self, user: Snowflake) -> BanEntry:
"""|coro| """|coro|
@ -1752,7 +1795,7 @@ class Guild(Hashable):
The :class:`BanEntry` object for the specified user. The :class:`BanEntry` object for the specified user.
""" """
data: BanPayload = await self._state.http.get_ban(user.id, self.id) data: BanPayload = await self._state.http.get_ban(user.id, self.id)
return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"])
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]:
"""|coro| """|coro|
@ -1785,16 +1828,16 @@ class Guild(Hashable):
""" """
data = await self._state.http.get_channel(channel_id) data = await self._state.http.get_channel(channel_id)
factory, ch_type = _threaded_guild_channel_factory(data['type']) factory, ch_type = _threaded_guild_channel_factory(data["type"])
if factory is None: if factory is None:
raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
raise InvalidData('Channel ID resolved to a private channel') raise InvalidData("Channel ID resolved to a private channel")
guild_id = int(data['guild_id']) guild_id = int(data["guild_id"])
if self.id != guild_id: if self.id != guild_id:
raise InvalidData('Guild ID resolved to a different guild') raise InvalidData("Guild ID resolved to a different guild")
channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore
return channel return channel
@ -1821,7 +1864,7 @@ class Guild(Hashable):
""" """
data: List[BanPayload] = await self._state.http.get_bans(self.id) data: List[BanPayload] = await self._state.http.get_bans(self.id)
return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] return [BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) for e in data]
async def prune_members( async def prune_members(
self, self,
@ -1881,7 +1924,7 @@ class Guild(Hashable):
""" """
if not isinstance(days, int): if not isinstance(days, int):
raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.")
if roles: if roles:
role_ids = [str(role.id) for role in roles] role_ids = [str(role.id) for role in roles]
@ -1891,7 +1934,7 @@ class Guild(Hashable):
data = await self._state.http.prune_members( data = await self._state.http.prune_members(
self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason
) )
return data['pruned'] return data["pruned"]
async def templates(self) -> List[Template]: async def templates(self) -> List[Template]:
"""|coro| """|coro|
@ -1973,7 +2016,7 @@ class Guild(Hashable):
""" """
if not isinstance(days, int): if not isinstance(days, int):
raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.")
if roles: if roles:
role_ids = [str(role.id) for role in roles] role_ids = [str(role.id) for role in roles]
@ -1981,7 +2024,7 @@ class Guild(Hashable):
role_ids = [] role_ids = []
data = await self._state.http.estimate_pruned_members(self.id, days, role_ids) data = await self._state.http.estimate_pruned_members(self.id, days, role_ids)
return data['pruned'] return data["pruned"]
async def invites(self) -> List[Invite]: async def invites(self) -> List[Invite]:
"""|coro| """|coro|
@ -2007,7 +2050,7 @@ class Guild(Hashable):
data = await self._state.http.invites_from(self.id) data = await self._state.http.invites_from(self.id)
result = [] result = []
for invite in data: for invite in data:
channel = self.get_channel(int(invite['channel']['id'])) channel = self.get_channel(int(invite["channel"]["id"]))
result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) result.append(Invite(state=self._state, data=invite, guild=self, channel=channel))
return result return result
@ -2031,10 +2074,10 @@ class Guild(Hashable):
""" """
from .template import Template from .template import Template
payload = {'name': name} payload = {"name": name}
if description: if description:
payload['description'] = description payload["description"] = description
data = await self._state.http.create_template(self.id, payload) data = await self._state.http.create_template(self.id, payload)
@ -2091,9 +2134,9 @@ class Guild(Hashable):
data = await self._state.http.get_all_integrations(self.id) data = await self._state.http.get_all_integrations(self.id)
def convert(d): def convert(d):
factory, _ = _integration_factory(d['type']) factory, _ = _integration_factory(d["type"])
if factory is None: if factory is None:
raise InvalidData('Unknown integration type {type!r} for integration ID {id}'.format_map(d)) raise InvalidData("Unknown integration type {type!r} for integration ID {id}".format_map(d))
return factory(guild=self, data=d) return factory(guild=self, data=d)
return [convert(d) for d in data] return [convert(d) for d in data]
@ -2198,20 +2241,20 @@ class Guild(Hashable):
The created sticker. The created sticker.
""" """
payload = { payload = {
'name': name, "name": name,
} }
if description: if description:
payload['description'] = description payload["description"] = description
try: try:
emoji = unicodedata.name(emoji) emoji = unicodedata.name(emoji)
except TypeError: except TypeError:
pass pass
else: else:
emoji = emoji.replace(' ', '_') emoji = emoji.replace(" ", "_")
payload['tags'] = emoji payload["tags"] = emoji
data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) data = await self._state.http.create_guild_sticker(self.id, payload, file, reason)
return self._state.store_sticker(self, data) return self._state.store_sticker(self, data)
@ -2478,24 +2521,24 @@ class Guild(Hashable):
""" """
fields: Dict[str, Any] = {} fields: Dict[str, Any] = {}
if permissions is not MISSING: if permissions is not MISSING:
fields['permissions'] = str(permissions.value) fields["permissions"] = str(permissions.value)
else: else:
fields['permissions'] = '0' fields["permissions"] = "0"
actual_colour = colour or color or Colour.default() actual_colour = colour or color or Colour.default()
if isinstance(actual_colour, int): if isinstance(actual_colour, int):
fields['color'] = actual_colour fields["color"] = actual_colour
else: else:
fields['color'] = actual_colour.value fields["color"] = actual_colour.value
if hoist is not MISSING: if hoist is not MISSING:
fields['hoist'] = hoist fields["hoist"] = hoist
if mentionable is not MISSING: if mentionable is not MISSING:
fields['mentionable'] = mentionable fields["mentionable"] = mentionable
if name is not MISSING: if name is not MISSING:
fields['name'] = name fields["name"] = name
data = await self._state.http.create_role(self.id, reason=reason, **fields) data = await self._state.http.create_role(self.id, reason=reason, **fields)
role = Role(guild=self, data=data, state=self._state) role = Role(guild=self, data=data, state=self._state)
@ -2548,12 +2591,12 @@ class Guild(Hashable):
A list of all the roles in the guild. A list of all the roles in the guild.
""" """
if not isinstance(positions, dict): if not isinstance(positions, dict):
raise InvalidArgument('positions parameter expects a dict.') raise InvalidArgument("positions parameter expects a dict.")
role_positions: List[Dict[str, Any]] = [] role_positions: List[Dict[str, Any]] = []
for role, position in positions.items(): for role, position in positions.items():
payload = {'id': role.id, 'position': position} payload = {"id": role.id, "position": position}
role_positions.append(payload) role_positions.append(payload)
@ -2679,19 +2722,19 @@ class Guild(Hashable):
# we start with { code: abc } # we start with { code: abc }
payload = await self._state.http.get_vanity_code(self.id) payload = await self._state.http.get_vanity_code(self.id)
if not payload['code']: if not payload["code"]:
return None return None
# get the vanity URL channel since default channels aren't # get the vanity URL channel since default channels aren't
# reliable or a thing anymore # reliable or a thing anymore
data = await self._state.http.get_invite(payload['code']) data = await self._state.http.get_invite(payload["code"])
channel = self.get_channel(int(data['channel']['id'])) channel = self.get_channel(int(data["channel"]["id"]))
payload['revoked'] = False payload["revoked"] = False
payload['temporary'] = False payload["temporary"] = False
payload['max_uses'] = 0 payload["max_uses"] = 0
payload['max_age'] = 0 payload["max_age"] = 0
payload['uses'] = payload.get('uses', 0) payload["uses"] = payload.get("uses", 0)
return Invite(state=self._state, data=payload, guild=self, channel=channel) return Invite(state=self._state, data=payload, guild=self, channel=channel)
# TODO: use MISSING when async iterators get refactored # TODO: use MISSING when async iterators get refactored
@ -2768,7 +2811,13 @@ class Guild(Hashable):
action = action.value action = action.value
return AuditLogIterator( return AuditLogIterator(
self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action self,
before=before,
after=after,
limit=limit,
oldest_first=oldest_first,
user_id=user_id,
action_type=action,
) )
async def widget(self) -> Widget: async def widget(self) -> Widget:
@ -2822,9 +2871,9 @@ class Guild(Hashable):
""" """
payload = {} payload = {}
if channel is not MISSING: if channel is not MISSING:
payload['channel_id'] = None if channel is None else channel.id payload["channel_id"] = None if channel is None else channel.id
if enabled is not MISSING: if enabled is not MISSING:
payload['enabled'] = enabled payload["enabled"] = enabled
await self._state.http.edit_widget(self.id, payload=payload) await self._state.http.edit_widget(self.id, payload=payload)
@ -2850,7 +2899,7 @@ class Guild(Hashable):
""" """
if not self._state._intents.members: if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.') raise ClientException("Intents.members must be enabled to use this.")
if not self._state.is_guild_evicted(self): if not self._state.is_guild_evicted(self):
return await self._state.chunk_guild(self, cache=cache) return await self._state.chunk_guild(self, cache=cache)
@ -2911,20 +2960,20 @@ class Guild(Hashable):
""" """
if presences and not self._state._intents.presences: if presences and not self._state._intents.presences:
raise ClientException('Intents.presences must be enabled to use this.') raise ClientException("Intents.presences must be enabled to use this.")
if query is None: if query is None:
if query == '': if query == "":
raise ValueError('Cannot pass empty query string.') raise ValueError("Cannot pass empty query string.")
if user_ids is None: if user_ids is None:
raise ValueError('Must pass either query or user_ids') raise ValueError("Must pass either query or user_ids")
if user_ids is not None and query is not None: if user_ids is not None and query is not None:
raise ValueError('Cannot pass both query and user_ids') raise ValueError("Cannot pass both query and user_ids")
if user_ids is not None and not user_ids: if user_ids is not None and not user_ids:
raise ValueError('user_ids must contain at least 1 value') raise ValueError("user_ids must contain at least 1 value")
limit = min(100, limit or 5) limit = min(100, limit or 5)
return await self._state.query_members( return await self._state.query_members(

File diff suppressed because it is too large Load Diff

View File

@ -32,11 +32,11 @@ from .errors import InvalidArgument
from .enums import try_enum, ExpireBehaviour from .enums import try_enum, ExpireBehaviour
__all__ = ( __all__ = (
'IntegrationAccount', "IntegrationAccount",
'IntegrationApplication', "IntegrationApplication",
'Integration', "Integration",
'StreamIntegration', "StreamIntegration",
'BotIntegration', "BotIntegration",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -65,14 +65,14 @@ class IntegrationAccount:
The account name. The account name.
""" """
__slots__ = ('id', 'name') __slots__ = ("id", "name")
def __init__(self, data: IntegrationAccountPayload) -> None: def __init__(self, data: IntegrationAccountPayload) -> None:
self.id: str = data['id'] self.id: str = data["id"]
self.name: str = data['name'] self.name: str = data["name"]
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<IntegrationAccount id={self.id} name={self.name!r}>' return f"<IntegrationAccount id={self.id} name={self.name!r}>"
class Integration: class Integration:
@ -99,14 +99,14 @@ class Integration:
""" """
__slots__ = ( __slots__ = (
'guild', "guild",
'id', "id",
'_state', "_state",
'type', "type",
'name', "name",
'account', "account",
'user', "user",
'enabled', "enabled",
) )
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
@ -118,14 +118,14 @@ class Integration:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None: def _from_data(self, data: IntegrationPayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.type: IntegrationType = data['type'] self.type: IntegrationType = data["type"]
self.name: str = data['name'] self.name: str = data["name"]
self.account: IntegrationAccount = IntegrationAccount(data['account']) self.account: IntegrationAccount = IntegrationAccount(data["account"])
user = data.get('user') user = data.get("user")
self.user = User(state=self._state, data=user) if user else None self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled'] self.enabled: bool = data["enabled"]
async def delete(self, *, reason: Optional[str] = None) -> None: async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|
@ -186,26 +186,26 @@ class StreamIntegration(Integration):
""" """
__slots__ = ( __slots__ = (
'revoked', "revoked",
'expire_behaviour', "expire_behaviour",
'expire_grace_period', "expire_grace_period",
'synced_at', "synced_at",
'_role_id', "_role_id",
'syncing', "syncing",
'enable_emoticons', "enable_emoticons",
'subscriber_count', "subscriber_count",
) )
def _from_data(self, data: StreamIntegrationPayload) -> None: def _from_data(self, data: StreamIntegrationPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.revoked: bool = data['revoked'] self.revoked: bool = data["revoked"]
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior']) self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"])
self.expire_grace_period: int = data['expire_grace_period'] self.expire_grace_period: int = data["expire_grace_period"]
self.synced_at: datetime.datetime = parse_time(data['synced_at']) self.synced_at: datetime.datetime = parse_time(data["synced_at"])
self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id') self._role_id: Optional[int] = _get_as_snowflake(data, "role_id")
self.syncing: bool = data['syncing'] self.syncing: bool = data["syncing"]
self.enable_emoticons: bool = data['enable_emoticons'] self.enable_emoticons: bool = data["enable_emoticons"]
self.subscriber_count: int = data['subscriber_count'] self.subscriber_count: int = data["subscriber_count"]
@property @property
def expire_behavior(self) -> ExpireBehaviour: def expire_behavior(self) -> ExpireBehaviour:
@ -252,15 +252,15 @@ class StreamIntegration(Integration):
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if expire_behaviour is not MISSING: if expire_behaviour is not MISSING:
if not isinstance(expire_behaviour, ExpireBehaviour): if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour') raise InvalidArgument("expire_behaviour field must be of type ExpireBehaviour")
payload['expire_behavior'] = expire_behaviour.value payload["expire_behavior"] = expire_behaviour.value
if expire_grace_period is not MISSING: if expire_grace_period is not MISSING:
payload['expire_grace_period'] = expire_grace_period payload["expire_grace_period"] = expire_grace_period
if enable_emoticons is not MISSING: if enable_emoticons is not MISSING:
payload['enable_emoticons'] = enable_emoticons payload["enable_emoticons"] = enable_emoticons
# This endpoint is undocumented. # This endpoint is undocumented.
# Unsure if it returns the data or not as a result # Unsure if it returns the data or not as a result
@ -307,21 +307,21 @@ class IntegrationApplication:
""" """
__slots__ = ( __slots__ = (
'id', "id",
'name', "name",
'icon', "icon",
'description', "description",
'summary', "summary",
'user', "user",
) )
def __init__(self, *, data: IntegrationApplicationPayload, state): def __init__(self, *, data: IntegrationApplicationPayload, state):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self.icon: Optional[str] = data['icon'] self.icon: Optional[str] = data["icon"]
self.description: str = data['description'] self.description: str = data["description"]
self.summary: str = data['summary'] self.summary: str = data["summary"]
user = data.get('bot') user = data.get("bot")
self.user: Optional[User] = User(state=state, data=user) if user else None self.user: Optional[User] = User(state=state, data=user) if user else None
@ -350,17 +350,17 @@ class BotIntegration(Integration):
The application tied to this integration. The application tied to this integration.
""" """
__slots__ = ('application',) __slots__ = ("application",)
def _from_data(self, data: BotIntegrationPayload) -> None: def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state) self.application = IntegrationApplication(data=data["application"], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]: def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
if value == 'discord': if value == "discord":
return BotIntegration, value return BotIntegration, value
elif value in ('twitch', 'youtube'): elif value in ("twitch", "youtube"):
return StreamIntegration, value return StreamIntegration, value
else: else:
return Integration, value return Integration, value

View File

@ -41,14 +41,17 @@ from .permissions import Permissions
from .webhook.async_ import async_context, Webhook, handle_message_parameters from .webhook.async_ import async_context, Webhook, handle_message_parameters
__all__ = ( __all__ = (
'Interaction', "Interaction",
'InteractionMessage', "InteractionMessage",
'InteractionResponse', "InteractionResponse",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from datetime import datetime
from .types.interactions import ( from .types.interactions import (
Interaction as InteractionPayload, Interaction as InteractionPayload,
ApplicationCommandOptionChoice,
InteractionData, InteractionData,
) )
from .guild import Guild from .guild import Guild
@ -58,12 +61,10 @@ if TYPE_CHECKING:
from aiohttp import ClientSession from aiohttp import ClientSession
from .embeds import Embed from .embeds import Embed
from .ui.view import View from .ui.view import View
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable from .channel import TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .threads import Thread from .threads import Thread
InteractionChannel = Union[ InteractionChannel = Union[TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable]
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
]
MISSING: Any = utils.MISSING MISSING: Any = utils.MISSING
@ -100,23 +101,23 @@ class Interaction:
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'id', "id",
'type', "type",
'guild_id', "guild_id",
'channel_id', "channel_id",
'data', "data",
'application_id', "application_id",
'message', "message",
'user', "user",
'token', "token",
'version', "version",
'_permissions', "_permissions",
'_state', "_state",
'_session', "_session",
'_original_message', "_original_message",
'_cs_response', "_cs_response",
'_cs_followup', "_cs_followup",
'_cs_channel', "_cs_channel",
) )
def __init__(self, *, data: InteractionPayload, state: ConnectionState): def __init__(self, *, data: InteractionPayload, state: ConnectionState):
@ -126,18 +127,18 @@ class Interaction:
self._from_data(data) self._from_data(data)
def _from_data(self, data: InteractionPayload): def _from_data(self, data: InteractionPayload):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.type: InteractionType = try_enum(InteractionType, data['type']) self.type: InteractionType = try_enum(InteractionType, data["type"])
self.data: Optional[InteractionData] = data.get('data') self.data: Optional[InteractionData] = data.get("data")
self.token: str = data['token'] self.token: str = data["token"]
self.version: int = data['version'] self.version: int = data["version"]
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id")
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.application_id: int = int(data['application_id']) self.application_id: int = int(data["application_id"])
self.message: Optional[Message] self.message: Optional[Message]
try: try:
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore
except KeyError: except KeyError:
self.message = None self.message = None
@ -148,15 +149,15 @@ class Interaction:
if self.guild_id: if self.guild_id:
guild = self.guild or Object(id=self.guild_id) guild = self.guild or Object(id=self.guild_id)
try: try:
member = data['member'] # type: ignore member = data["member"] # type: ignore
except KeyError: except KeyError:
pass pass
else: else:
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
self._permissions = int(member.get('permissions', 0)) self._permissions = int(member.get("permissions", 0))
else: else:
try: try:
self.user = User(state=self._state, data=data['user']) self.user = User(state=self._state, data=data["user"])
except KeyError: except KeyError:
pass pass
@ -165,7 +166,7 @@ class Interaction:
"""Optional[:class:`Guild`]: The guild the interaction was sent from.""" """Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id) return self._state and self._state._get_guild(self.guild_id)
@utils.cached_slot_property('_cs_channel') @utils.cached_slot_property("_cs_channel")
def channel(self) -> Optional[InteractionChannel]: def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from. """Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
@ -179,7 +180,7 @@ class Interaction:
type = ChannelType.text if self.guild_id is not None else ChannelType.private type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type) return PartialMessageable(state=self._state, id=self.channel_id, type=type)
return None return None
return channel return channel # type: ignore
@property @property
def permissions(self) -> Permissions: def permissions(self) -> Permissions:
@ -189,7 +190,7 @@ class Interaction:
""" """
return Permissions(self._permissions) return Permissions(self._permissions)
@utils.cached_slot_property('_cs_response') @utils.cached_slot_property("_cs_response")
def response(self) -> InteractionResponse: def response(self) -> InteractionResponse:
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction. """:class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
@ -198,13 +199,13 @@ class Interaction:
""" """
return InteractionResponse(self) return InteractionResponse(self)
@utils.cached_slot_property('_cs_followup') @utils.cached_slot_property("_cs_followup")
def followup(self) -> Webhook: def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions.""" """:class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = { payload = {
'id': self.application_id, "id": self.application_id,
'type': 3, "type": 3,
'token': self.token, "token": self.token,
} }
return Webhook.from_state(data=payload, state=self._state) return Webhook.from_state(data=payload, state=self._state)
@ -238,7 +239,7 @@ class Interaction:
# TODO: fix later to not raise? # TODO: fix later to not raise?
channel = self.channel channel = self.channel
if channel is None: if channel is None:
raise ClientException('Channel for message could not be resolved') raise ClientException("Channel for message could not be resolved")
adapter = async_context.get() adapter = async_context.get()
data = await adapter.get_original_interaction_response( data = await adapter.get_original_interaction_response(
@ -369,20 +370,20 @@ class InteractionResponse:
""" """
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'_responded', "responded_at",
'_parent', "_parent",
) )
def __init__(self, parent: Interaction): def __init__(self, parent: Interaction):
self.responded_at: Optional[datetime] = None
self._parent: Interaction = parent self._parent: Interaction = parent
self._responded: bool = False
def is_done(self) -> bool: def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before. """:class:`bool`: Indicates whether an interaction response has been done before.
An interaction can only be responded to once. An interaction can only be responded to once.
""" """
return self._responded return self.responded_at is not None
async def defer(self, *, ephemeral: bool = False) -> None: async def defer(self, *, ephemeral: bool = False) -> None:
"""|coro| """|coro|
@ -405,7 +406,7 @@ class InteractionResponse:
InteractionResponded InteractionResponded
This interaction has already been responded to before. This interaction has already been responded to before.
""" """
if self._responded: if self.is_done():
raise InteractionResponded(self._parent) raise InteractionResponded(self._parent)
defer_type: int = 0 defer_type: int = 0
@ -416,14 +417,15 @@ class InteractionResponse:
elif parent.type is InteractionType.application_command: elif parent.type is InteractionType.application_command:
defer_type = InteractionResponseType.deferred_channel_message.value defer_type = InteractionResponseType.deferred_channel_message.value
if ephemeral: if ephemeral:
data = {'flags': 64} data = {"flags": 64}
if defer_type: if defer_type:
adapter = async_context.get() adapter = async_context.get()
await adapter.create_interaction_response( await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=defer_type, data=data parent.id, parent.token, session=parent._session, type=defer_type, data=data
) )
self._responded = True
self.responded_at = utils.utcnow()
async def pong(self) -> None: async def pong(self) -> None:
"""|coro| """|coro|
@ -439,7 +441,7 @@ class InteractionResponse:
InteractionResponded InteractionResponded
This interaction has already been responded to before. This interaction has already been responded to before.
""" """
if self._responded: if self.is_done():
raise InteractionResponded(self._parent) raise InteractionResponded(self._parent)
parent = self._parent parent = self._parent
@ -448,7 +450,7 @@ class InteractionResponse:
await adapter.create_interaction_response( await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
) )
self._responded = True self.responded_at = utils.utcnow()
async def send_message( async def send_message(
self, self,
@ -494,32 +496,32 @@ class InteractionResponse:
InteractionResponded InteractionResponded
This interaction has already been responded to before. This interaction has already been responded to before.
""" """
if self._responded: if self.is_done():
raise InteractionResponded(self._parent) raise InteractionResponded(self._parent)
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
'tts': tts, "tts": tts,
} }
if embed is not MISSING and embeds is not MISSING: if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix embed and embeds keyword arguments') raise TypeError("cannot mix embed and embeds keyword arguments")
if embed is not MISSING: if embed is not MISSING:
embeds = [embed] embeds = [embed]
if embeds: if embeds:
if len(embeds) > 10: if len(embeds) > 10:
raise ValueError('embeds cannot exceed maximum of 10 elements') raise ValueError("embeds cannot exceed maximum of 10 elements")
payload['embeds'] = [e.to_dict() for e in embeds] payload["embeds"] = [e.to_dict() for e in embeds]
if content is not None: if content is not None:
payload['content'] = str(content) payload["content"] = str(content)
if ephemeral: if ephemeral:
payload['flags'] = 64 payload["flags"] = 64
if view is not MISSING: if view is not MISSING:
payload['components'] = view.to_components() payload["components"] = view.to_components()
parent = self._parent parent = self._parent
adapter = async_context.get() adapter = async_context.get()
@ -537,7 +539,7 @@ class InteractionResponse:
self._parent._state.store_view(view) self._parent._state.store_view(view)
self._responded = True self.responded_at = utils.utcnow()
async def edit_message( async def edit_message(
self, self,
@ -578,7 +580,7 @@ class InteractionResponse:
InteractionResponded InteractionResponded
This interaction has already been responded to before. This interaction has already been responded to before.
""" """
if self._responded: if self.is_done():
raise InteractionResponded(self._parent) raise InteractionResponded(self._parent)
parent = self._parent parent = self._parent
@ -591,12 +593,12 @@ class InteractionResponse:
payload = {} payload = {}
if content is not MISSING: if content is not MISSING:
if content is None: if content is None:
payload['content'] = None payload["content"] = None
else: else:
payload['content'] = str(content) payload["content"] = str(content)
if embed is not MISSING and embeds is not MISSING: if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix both embed and embeds keyword arguments') raise TypeError("cannot mix both embed and embeds keyword arguments")
if embed is not MISSING: if embed is not MISSING:
if embed is None: if embed is None:
@ -605,17 +607,17 @@ class InteractionResponse:
embeds = [embed] embeds = [embed]
if embeds is not MISSING: if embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds] payload["embeds"] = [e.to_dict() for e in embeds]
if attachments is not MISSING: if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments] payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING: if view is not MISSING:
state.prevent_view_updates_for(message_id) state.prevent_view_updates_for(message_id)
if view is None: if view is None:
payload['components'] = [] payload["components"] = []
else: else:
payload['components'] = view.to_components() payload["components"] = view.to_components()
adapter = async_context.get() adapter = async_context.get()
await adapter.create_interaction_response( await adapter.create_interaction_response(
@ -629,11 +631,48 @@ class InteractionResponse:
if view and not view.is_finished(): if view and not view.is_finished():
state.store_view(view, message_id) state.store_view(view, message_id)
self._responded = True self.responded_at = utils.utcnow()
async def autocomplete_result(self, choices: List[ApplicationCommandOptionChoice]):
"""|coro|
Responds to this autocomplete interaction with the resulting choices.
This should rarely be used.
Parameters
-----------
choices: List[Dict[:class:`str`, :class:`str`]]
The choices to be shown in the autocomplete UI of the user.
Must be a list of dictionaries with the ``name`` and ``value`` keys.
Raises
-------
HTTPException
Responding to the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
parent = self._parent
if parent.type is not InteractionType.application_command_autocomplete:
return
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.application_command_autocomplete_result.value,
data={"choices": choices},
)
self.responded_at = utils.utcnow()
class _InteractionMessageState: class _InteractionMessageState:
__slots__ = ('_parent', '_interaction') __slots__ = ("_parent", "_interaction")
def __init__(self, interaction: Interaction, parent: ConnectionState): def __init__(self, interaction: Interaction, parent: ConnectionState):
self._interaction: Interaction = interaction self._interaction: Interaction = interaction

View File

@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum
from .appinfo import PartialAppInfo from .appinfo import PartialAppInfo
__all__ = ( __all__ = (
'PartialInviteChannel', "PartialInviteChannel",
'PartialInviteGuild', "PartialInviteGuild",
'Invite', "Invite",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -52,8 +52,8 @@ if TYPE_CHECKING:
from .abc import GuildChannel from .abc import GuildChannel
from .user import User from .user import User
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] InviteGuildType = Union[Guild, "PartialInviteGuild", Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object]
import datetime import datetime
@ -92,23 +92,23 @@ class PartialInviteChannel:
The partial channel's type. The partial channel's type.
""" """
__slots__ = ('id', 'name', 'type') __slots__ = ("id", "name", "type")
def __init__(self, data: InviteChannelPayload): def __init__(self, data: InviteChannelPayload):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self.type: ChannelType = try_enum(ChannelType, data['type']) self.type: ChannelType = try_enum(ChannelType, data["type"])
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>' return f"<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>"
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel.""" """:class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>' return f"<#{self.id}>"
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:
@ -154,26 +154,26 @@ class PartialInviteGuild:
The partial guild's description. The partial guild's description.
""" """
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') __slots__ = ("_state", "features", "_icon", "_banner", "id", "name", "_splash", "verification_level", "description")
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = id self.id: int = id
self.name: str = data['name'] self.name: str = data["name"]
self.features: List[str] = data.get('features', []) self.features: List[str] = data.get("features", [])
self._icon: Optional[str] = data.get('icon') self._icon: Optional[str] = data.get("icon")
self._banner: Optional[str] = data.get('banner') self._banner: Optional[str] = data.get("banner")
self._splash: Optional[str] = data.get('splash') self._splash: Optional[str] = data.get("splash")
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get("verification_level"))
self.description: Optional[str] = data.get('description') self.description: Optional[str] = data.get("description")
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} ' f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} "
f'description={self.description!r}>' f"description={self.description!r}>"
) )
@property @property
@ -193,17 +193,17 @@ class PartialInviteGuild:
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None: if self._banner is None:
return None return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') return Asset._from_guild_image(self._state, self.id, self._banner, path="banners")
@property @property
def splash(self) -> Optional[Asset]: def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None: if self._splash is None:
return None return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes")
I = TypeVar('I', bound='Invite') I = TypeVar("I", bound="Invite")
class Invite(Hashable): class Invite(Hashable):
@ -230,6 +230,7 @@ class Invite(Hashable):
Returns the invite URL. Returns the invite URL.
The following table illustrates what methods will obtain the attributes: The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+ +------------------------------------+------------------------------------------------------------+
@ -307,26 +308,26 @@ class Invite(Hashable):
""" """
__slots__ = ( __slots__ = (
'max_age', "max_age",
'code', "code",
'guild', "guild",
'revoked', "revoked",
'created_at', "created_at",
'uses', "uses",
'temporary', "temporary",
'max_uses', "max_uses",
'inviter', "inviter",
'channel', "channel",
'target_user', "target_user",
'target_type', "target_type",
'_state', "_state",
'approximate_member_count', "approximate_member_count",
'approximate_presence_count', "approximate_presence_count",
'target_application', "target_application",
'expires_at', "expires_at",
) )
BASE = 'https://discord.gg' BASE = "https://discord.gg"
def __init__( def __init__(
self, self,
@ -337,31 +338,33 @@ class Invite(Hashable):
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
): ):
self._state: ConnectionState = state self._state: ConnectionState = state
self.max_age: Optional[int] = data.get('max_age') self.max_age: Optional[int] = data.get("max_age")
self.code: str = data['code'] self.code: str = data["code"]
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild)
self.revoked: Optional[bool] = data.get('revoked') self.revoked: Optional[bool] = data.get("revoked")
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.temporary: Optional[bool] = data.get('temporary') self.temporary: Optional[bool] = data.get("temporary")
self.uses: Optional[int] = data.get('uses') self.uses: Optional[int] = data.get("uses")
self.max_uses: Optional[int] = data.get('max_uses') self.max_uses: Optional[int] = data.get("max_uses")
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') self.approximate_presence_count: Optional[int] = data.get("approximate_presence_count")
self.approximate_member_count: Optional[int] = data.get('approximate_member_count') self.approximate_member_count: Optional[int] = data.get("approximate_member_count")
expires_at = data.get('expires_at', None) expires_at = data.get("expires_at", None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
inviter_data = data.get('inviter') inviter_data = data.get("inviter")
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel) self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get("channel"), channel)
target_user_data = data.get('target_user') target_user_data = data.get("target_user")
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data) self.target_user: Optional[User] = (
None if target_user_data is None else self._state.create_user(target_user_data)
)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
application = data.get('target_application') application = data.get("target_application")
self.target_application: Optional[PartialAppInfo] = ( self.target_application: Optional[PartialAppInfo] = (
PartialAppInfo(data=application, state=state) if application else None PartialAppInfo(data=application, state=state) if application else None
) )
@ -370,12 +373,12 @@ class Invite(Hashable):
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I:
guild: Optional[Union[Guild, PartialInviteGuild]] guild: Optional[Union[Guild, PartialInviteGuild]]
try: try:
guild_data = data['guild'] guild_data = data["guild"]
except KeyError: except KeyError:
# If we're here, then this is a group DM # If we're here, then this is a group DM
guild = None guild = None
else: else:
guild_id = int(guild_data['id']) guild_id = int(guild_data["id"])
guild = state._get_guild(guild_id) guild = state._get_guild(guild_id)
if guild is None: if guild is None:
# If it's not cached, then it has to be a partial guild # If it's not cached, then it has to be a partial guild
@ -383,7 +386,7 @@ class Invite(Hashable):
# As far as I know, invites always need a channel # As far as I know, invites always need a channel
# So this should never raise. # So this should never raise.
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data["channel"])
if guild is not None and not isinstance(guild, PartialInviteGuild): if guild is not None and not isinstance(guild, PartialInviteGuild):
# Upgrade the partial data if applicable # Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel channel = guild.get_channel(channel.id) or channel
@ -392,9 +395,9 @@ class Invite(Hashable):
@classmethod @classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data['channel_id']) channel_id = int(data["channel_id"])
if guild is not None: if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
else: else:
@ -414,7 +417,7 @@ class Invite(Hashable):
if data is None: if data is None:
return None return None
guild_id = int(data['id']) guild_id = int(data["id"])
return PartialInviteGuild(self._state, data, guild_id) return PartialInviteGuild(self._state, data, guild_id)
def _resolve_channel( def _resolve_channel(
@ -433,11 +436,14 @@ class Invite(Hashable):
def __str__(self) -> str: def __str__(self) -> str:
return self.url return self.url
def __int__(self) -> int:
return 0 # To keep the object compatible with the hashable abc.
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Invite code={self.code!r} guild={self.guild!r} ' f"<Invite code={self.code!r} guild={self.guild!r} "
f'online={self.approximate_presence_count} ' f"online={self.approximate_presence_count} "
f'members={self.approximate_member_count}>' f"members={self.approximate_member_count}>"
) )
def __hash__(self) -> int: def __hash__(self) -> int:
@ -451,7 +457,7 @@ class Invite(Hashable):
@property @property
def url(self) -> str: def url(self) -> str:
""":class:`str`: A property that retrieves the invite URL.""" """:class:`str`: A property that retrieves the invite URL."""
return self.BASE + '/' + self.code return self.BASE + "/" + self.code
async def delete(self, *, reason: Optional[str] = None): async def delete(self, *, reason: Optional[str] = None):
"""|coro| """|coro|

View File

@ -34,11 +34,11 @@ from .object import Object
from .audit_logs import AuditLogEntry from .audit_logs import AuditLogEntry
__all__ = ( __all__ = (
'ReactionIterator', "ReactionIterator",
'HistoryIterator', "HistoryIterator",
'AuditLogIterator', "AuditLogIterator",
'GuildIterator', "GuildIterator",
'MemberIterator', "MemberIterator",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -67,8 +67,8 @@ if TYPE_CHECKING:
from .threads import Thread from .threads import Thread
from .abc import Snowflake from .abc import Snowflake
T = TypeVar('T') T = TypeVar("T")
OT = TypeVar('OT') OT = TypeVar("OT")
_Func = Callable[[T], Union[OT, Awaitable[OT]]] _Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0) OLDEST_OBJECT = Object(id=0)
@ -83,7 +83,7 @@ class _AsyncIterator(AsyncIterator[T]):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]: def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T): def predicate(elem: T):
for attr, val in attrs.items(): for attr, val in attrs.items():
nested = attr.split('__') nested = attr.split("__")
obj = elem obj = elem
for attribute in nested: for attribute in nested:
obj = getattr(obj, attribute) obj = getattr(obj, attribute)
@ -107,7 +107,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0: if max_size <= 0:
raise ValueError('async iterator chunk sizes must be greater than 0.') raise ValueError("async iterator chunk sizes must be greater than 0.")
return _ChunkedAsyncIterator(self, max_size) return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@ -182,7 +182,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
return item return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
def __init__(self, message, emoji, limit=100, after=None): def __init__(self, message, emoji, limit=100, after=None):
self.message = message self.message = message
self.limit = limit self.limit = limit
@ -218,14 +218,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
if data: if data:
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(data[-1]['id'])) self.after = Object(id=int(data[-1]["id"]))
if self.guild is None or isinstance(self.guild, Object): if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data): for element in reversed(data):
await self.users.put(User(state=self.state, data=element)) await self.users.put(User(state=self.state, data=element))
else: else:
for element in reversed(data): for element in reversed(data):
member_id = int(element['id']) member_id = int(element["id"])
member = self.guild.get_member(member_id) member = self.guild.get_member(member_id)
if member is not None: if member is not None:
await self.users.put(member) await self.users.put(member)
@ -233,7 +233,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
await self.users.put(User(state=self.state, data=element)) await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']): class HistoryIterator(_AsyncIterator["Message"]):
"""Iterator for receiving a channel's message history. """Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here: The messages endpoint has two behaviours we care about here:
@ -295,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self.around: if self.around:
if self.limit is None: if self.limit is None:
raise ValueError('history does not support around with limit=None') raise ValueError("history does not support around with limit=None")
if self.limit > 101: if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter") raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101: elif self.limit == 101:
@ -303,20 +303,20 @@ class HistoryIterator(_AsyncIterator['Message']):
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after: if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
elif self.before: elif self.before:
self._filter = lambda m: int(m['id']) < self.before.id self._filter = lambda m: int(m["id"]) < self.before.id
elif self.after: elif self.after:
self._filter = lambda m: self.after.id < int(m['id']) self._filter = lambda m: self.after.id < int(m["id"])
else: else:
if self.reverse: if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before: if self.before:
self._filter = lambda m: int(m['id']) < self.before.id self._filter = lambda m: int(m["id"]) < self.before.id
else: else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT: if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m["id"]) > self.after.id
async def next(self) -> Message: async def next(self) -> Message:
if self.messages.empty(): if self.messages.empty():
@ -337,7 +337,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return r > 0 return r > 0
async def fill_messages(self): async def fill_messages(self):
if not hasattr(self, 'channel'): if not hasattr(self, "channel"):
# do the required set up # do the required set up
channel = await self.messageable._get_channel() channel = await self.messageable._get_channel()
self.channel = channel self.channel = channel
@ -367,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.before = Object(id=int(data[-1]['id'])) self.before = Object(id=int(data[-1]["id"]))
return data return data
async def _retrieve_messages_after_strategy(self, retrieve): async def _retrieve_messages_after_strategy(self, retrieve):
@ -377,7 +377,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(data[0]['id'])) self.after = Object(id=int(data[0]["id"]))
return data return data
async def _retrieve_messages_around_strategy(self, retrieve): async def _retrieve_messages_around_strategy(self, retrieve):
@ -390,7 +390,7 @@ class HistoryIterator(_AsyncIterator['Message']):
return [] return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']): class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False)) before = Object(id=time_snowflake(before, high=False))
@ -420,11 +420,11 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self.reverse: if self.reverse:
self._strategy = self._after_strategy self._strategy = self._after_strategy
if self.before: if self.before:
self._filter = lambda m: int(m['id']) < self.before.id self._filter = lambda m: int(m["id"]) < self.before.id
else: else:
self._strategy = self._before_strategy self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT: if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m["id"]) > self.after.id
async def _before_strategy(self, retrieve): async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None before = self.before.id if self.before else None
@ -432,24 +432,24 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
) )
entries = data.get('audit_log_entries', []) entries = data.get("audit_log_entries", [])
if len(data) and entries: if len(data) and entries:
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.before = Object(id=int(entries[-1]['id'])) self.before = Object(id=int(entries[-1]["id"]))
return data.get('users', []), entries return data.get("users", []), entries
async def _after_strategy(self, retrieve): async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None after = self.after.id if self.after else None
data: AuditLogPayload = await self.request( data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
) )
entries = data.get('audit_log_entries', []) entries = data.get("audit_log_entries", [])
if len(data) and entries: if len(data) and entries:
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(entries[0]['id'])) self.after = Object(id=int(entries[0]["id"]))
return data.get('users', []), entries return data.get("users", []), entries
async def next(self) -> AuditLogEntry: async def next(self) -> AuditLogEntry:
if self.entries.empty(): if self.entries.empty():
@ -488,13 +488,13 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
for element in data: for element in data:
# TODO: remove this if statement later # TODO: remove this if statement later
if element['action_type'] is None: if element["action_type"] is None:
continue continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator['Guild']): class GuildIterator(_AsyncIterator["Guild"]):
"""Iterator for receiving the client's guilds. """Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described The guilds endpoint has the same two behaviours as described
@ -543,7 +543,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if self.before and self.after: if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m["id"]) > self.after.id
elif self.after: elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else: else:
@ -595,7 +595,7 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.before = Object(id=int(data[-1]['id'])) self.before = Object(id=int(data[-1]["id"]))
return data return data
async def _retrieve_guilds_after_strategy(self, retrieve): async def _retrieve_guilds_after_strategy(self, retrieve):
@ -605,11 +605,11 @@ class GuildIterator(_AsyncIterator['Guild']):
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(data[0]['id'])) self.after = Object(id=int(data[0]["id"]))
return data return data
class MemberIterator(_AsyncIterator['Member']): class MemberIterator(_AsyncIterator["Member"]):
def __init__(self, guild, limit=1000, after=None): def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
@ -652,7 +652,7 @@ class MemberIterator(_AsyncIterator['Member']):
if len(data) < 1000: if len(data) < 1000:
self.limit = 0 # terminate loop self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id'])) self.after = Object(id=int(data[-1]["user"]["id"]))
for element in reversed(data): for element in reversed(data):
await self.members.put(self.create_member(element)) await self.members.put(self.create_member(element))
@ -663,7 +663,7 @@ class MemberIterator(_AsyncIterator['Member']):
return Member(data=data, guild=self.guild, state=self.state) return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']): class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def __init__( def __init__(
self, self,
channel_id: int, channel_id: int,
@ -681,7 +681,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
self.http = guild._state.http self.http = guild._state.http
if joined and not private: if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads') raise ValueError("Cannot iterate over joined public archived threads")
self.before: Optional[str] self.before: Optional[str]
if before is None: if before is None:
@ -721,11 +721,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
@staticmethod @staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str: def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp'] return data["thread_metadata"]["archive_timestamp"]
@staticmethod @staticmethod
def get_thread_id(data: ThreadPayload) -> str: def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore return data["id"] # type: ignore
async def fill_queue(self) -> None: async def fill_queue(self) -> None:
if not self.has_more: if not self.has_more:
@ -735,11 +735,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
data = await self.endpoint(self.channel_id, before=self.before, limit=limit) data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty # This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', []) threads: List[ThreadPayload] = data.get("threads", [])
for d in reversed(threads): for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d)) self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False) self.has_more = data.get("has_more", False)
if self.limit is not None: if self.limit is not None:
self.limit -= len(threads) self.limit -= len(threads)
if self.limit <= 0: if self.limit <= 0:
@ -750,4 +750,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']):
def create_thread(self, data: ThreadPayload) -> Thread: def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data) return Thread(guild=self.guild, state=self.guild._state, data=data)

View File

@ -44,8 +44,8 @@ from .colour import Colour
from .object import Object from .object import Object
__all__ = ( __all__ = (
'VoiceState', "VoiceState",
'Member', "Member",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -113,52 +113,54 @@ class VoiceState:
""" """
__slots__ = ( __slots__ = (
'session_id', "session_id",
'deaf', "deaf",
'mute', "mute",
'self_mute', "self_mute",
'self_stream', "self_stream",
'self_video', "self_video",
'self_deaf', "self_deaf",
'afk', "afk",
'channel', "channel",
'requested_to_speak_at', "requested_to_speak_at",
'suppress', "suppress",
) )
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get('session_id') self.session_id: str = data.get("session_id")
self._update(data, channel) self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get('self_mute', False) self.self_mute: bool = data.get("self_mute", False)
self.self_deaf: bool = data.get('self_deaf', False) self.self_deaf: bool = data.get("self_deaf", False)
self.self_stream: bool = data.get('self_stream', False) self.self_stream: bool = data.get("self_stream", False)
self.self_video: bool = data.get('self_video', False) self.self_video: bool = data.get("self_video", False)
self.afk: bool = data.get('suppress', False) self.afk: bool = data.get("suppress", False)
self.mute: bool = data.get('mute', False) self.mute: bool = data.get("mute", False)
self.deaf: bool = data.get('deaf', False) self.deaf: bool = data.get("deaf", False)
self.suppress: bool = data.get('suppress', False) self.suppress: bool = data.get("suppress", False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp')) self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(
data.get("request_to_speak_timestamp")
)
self.channel: Optional[VocalGuildChannel] = channel self.channel: Optional[VocalGuildChannel] = channel
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = [ attrs = [
('self_mute', self.self_mute), ("self_mute", self.self_mute),
('self_deaf', self.self_deaf), ("self_deaf", self.self_deaf),
('self_stream', self.self_stream), ("self_stream", self.self_stream),
('suppress', self.suppress), ("suppress", self.suppress),
('requested_to_speak_at', self.requested_to_speak_at), ("requested_to_speak_at", self.requested_to_speak_at),
('channel', self.channel), ("channel", self.channel),
] ]
inner = ' '.join('%s=%r' % t for t in attrs) inner = " ".join("%s=%r" % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>' return f"<{self.__class__.__name__} {inner}>"
def flatten_user(cls): def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods # ignore private/special methods
if attr.startswith('_'): if attr.startswith("_"):
continue continue
# don't override what we already have # don't override what we already have
@ -167,9 +169,9 @@ def flatten_user(cls):
# if it's a slotted attribute or a property, redirect it # if it's a slotted attribute or a property, redirect it
# slotted members are implemented as member_descriptors in Type.__dict__ # slotted members are implemented as member_descriptors in Type.__dict__
if not hasattr(value, '__annotations__'): if not hasattr(value, "__annotations__"):
getter = attrgetter('_user.' + attr) getter = attrgetter("_user." + attr)
setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`')) setattr(cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`"))
else: else:
# Technically, this can also use attrgetter # Technically, this can also use attrgetter
# However I'm not sure how I feel about "functions" returning properties # However I'm not sure how I feel about "functions" returning properties
@ -197,7 +199,7 @@ def flatten_user(cls):
return cls return cls
M = TypeVar('M', bound='Member') M = TypeVar("M", bound="Member")
@flatten_user @flatten_user
@ -226,6 +228,10 @@ class Member(discord.abc.Messageable, _UserTag):
Returns the member's name with the discriminator. Returns the member's name with the discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes Attributes
---------- ----------
joined_at: Optional[:class:`datetime.datetime`] joined_at: Optional[:class:`datetime.datetime`]
@ -254,17 +260,17 @@ class Member(discord.abc.Messageable, _UserTag):
""" """
__slots__ = ( __slots__ = (
'_roles', "_roles",
'joined_at', "joined_at",
'premium_since', "premium_since",
'activities', "activities",
'guild', "guild",
'pending', "pending",
'nick', "nick",
'_client_status', "_client_status",
'_user', "_user",
'_state', "_state",
'_avatar', "_avatar",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -286,24 +292,27 @@ class Member(discord.abc.Messageable, _UserTag):
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState): def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state self._state: ConnectionState = state
self._user: User = state.store_user(data['user']) self._user: User = state.store_user(data["user"])
self.guild: Guild = guild self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at')) self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get("joined_at"))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since')) self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get("premium_since"))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles'])) self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"]))
self._client_status: Dict[Optional[str], str] = {None: 'offline'} self._client_status: Dict[Optional[str], str] = {None: "offline"}
self.activities: Tuple[ActivityTypes, ...] = tuple() self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None) self.nick: Optional[str] = data.get("nick", None)
self.pending: bool = data.get('pending', False) self.pending: bool = data.get("pending", False)
self._avatar: Optional[str] = data.get('avatar') self._avatar: Optional[str] = data.get("avatar")
def __str__(self) -> str: def __str__(self) -> str:
return str(self._user) return str(self._user)
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}' f"<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}"
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>' f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>"
) )
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
@ -318,25 +327,27 @@ class Member(discord.abc.Messageable, _UserTag):
@classmethod @classmethod
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M: def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
author = message.author author = message.author
data['user'] = author._to_minimal_user_json() # type: ignore data["user"] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore return cls(data=data, guild=message.guild, state=message._state) # type: ignore
def _update_from_message(self, data: MemberPayload) -> None: def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get('joined_at')) self.joined_at = utils.parse_time(data.get("joined_at"))
self.premium_since = utils.parse_time(data.get('premium_since')) self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data['roles'])) self._roles = utils.SnowflakeList(map(int, data["roles"]))
self.nick = data.get('nick', None) self.nick = data.get("nick", None)
self.pending = data.get('pending', False) self.pending = data.get("pending", False)
@classmethod @classmethod
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]: def _try_upgrade(
cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState
) -> Union[User, M]:
# A User object with a 'member' key # A User object with a 'member' key
try: try:
member_data = data.pop('member') member_data = data.pop("member")
except KeyError: except KeyError:
return state.create_user(data) return state.create_user(data)
else: else:
member_data['user'] = data # type: ignore member_data["user"] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore return cls(data=member_data, guild=guild, state=state) # type: ignore
@classmethod @classmethod
@ -367,25 +378,25 @@ class Member(discord.abc.Messageable, _UserTag):
# the nickname change is optional, # the nickname change is optional,
# if it isn't in the payload then it didn't change # if it isn't in the payload then it didn't change
try: try:
self.nick = data['nick'] self.nick = data["nick"]
except KeyError: except KeyError:
pass pass
try: try:
self.pending = data['pending'] self.pending = data["pending"]
except KeyError: except KeyError:
pass pass
self.premium_since = utils.parse_time(data.get('premium_since')) self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data['roles'])) self._roles = utils.SnowflakeList(map(int, data["roles"]))
self._avatar = data.get('avatar') self._avatar = data.get("avatar")
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data['activities'])) self.activities = tuple(map(create_activity, data["activities"]))
self._client_status = { self._client_status = {
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore
} }
self._client_status[None] = sys.intern(data['status']) self._client_status[None] = sys.intern(data["status"])
if len(user) > 1: if len(user) > 1:
return self._update_inner_user(user) return self._update_inner_user(user)
@ -395,7 +406,7 @@ class Member(discord.abc.Messageable, _UserTag):
u = self._user u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags) original = (u.name, u._avatar, u.discriminator, u._public_flags)
# These keys seem to always be available # These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0)) modified = (user["username"], user["avatar"], user["discriminator"], user.get("public_flags", 0))
if original != modified: if original != modified:
to_return = User._copy(self._user) to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified u.name, u._avatar, u.discriminator, u._public_flags = modified
@ -423,21 +434,21 @@ class Member(discord.abc.Messageable, _UserTag):
@property @property
def mobile_status(self) -> Status: def mobile_status(self) -> Status:
""":class:`Status`: The member's status on a mobile device, if applicable.""" """:class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get('mobile', 'offline')) return try_enum(Status, self._client_status.get("mobile", "offline"))
@property @property
def desktop_status(self) -> Status: def desktop_status(self) -> Status:
""":class:`Status`: The member's status on the desktop client, if applicable.""" """:class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get('desktop', 'offline')) return try_enum(Status, self._client_status.get("desktop", "offline"))
@property @property
def web_status(self) -> Status: def web_status(self) -> Status:
""":class:`Status`: The member's status on the web client, if applicable.""" """:class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get('web', 'offline')) return try_enum(Status, self._client_status.get("web", "offline"))
def is_on_mobile(self) -> bool: def is_on_mobile(self) -> bool:
""":class:`bool`: A helper function that determines if a member is active on a mobile device.""" """:class:`bool`: A helper function that determines if a member is active on a mobile device."""
return 'mobile' in self._client_status return "mobile" in self._client_status
@property @property
def colour(self) -> Colour: def colour(self) -> Colour:
@ -490,8 +501,8 @@ class Member(discord.abc.Messageable, _UserTag):
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the member.""" """:class:`str`: Returns a string that allows you to mention the member."""
if self.nick: if self.nick:
return f'<@!{self._user.id}>' return f"<@!{self._user.id}>"
return f'<@{self._user.id}>' return f"<@{self._user.id}>"
@property @property
def display_name(self) -> str: def display_name(self) -> str:
@ -713,39 +724,39 @@ class Member(discord.abc.Messageable, _UserTag):
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if nick is not MISSING: if nick is not MISSING:
nick = nick or '' nick = nick or ""
if me: if me:
await http.change_my_nickname(guild_id, nick, reason=reason) await http.change_my_nickname(guild_id, nick, reason=reason)
else: else:
payload['nick'] = nick payload["nick"] = nick
if deafen is not MISSING: if deafen is not MISSING:
payload['deaf'] = deafen payload["deaf"] = deafen
if mute is not MISSING: if mute is not MISSING:
payload['mute'] = mute payload["mute"] = mute
if suppress is not MISSING: if suppress is not MISSING:
voice_state_payload = { voice_state_payload = {
'channel_id': self.voice.channel.id, "channel_id": self.voice.channel.id,
'suppress': suppress, "suppress": suppress,
} }
if suppress or self.bot: if suppress or self.bot:
voice_state_payload['request_to_speak_timestamp'] = None voice_state_payload["request_to_speak_timestamp"] = None
if me: if me:
await http.edit_my_voice_state(guild_id, voice_state_payload) await http.edit_my_voice_state(guild_id, voice_state_payload)
else: else:
if not suppress: if not suppress:
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat() voice_state_payload["request_to_speak_timestamp"] = datetime.datetime.utcnow().isoformat()
await http.edit_voice_state(guild_id, self.id, voice_state_payload) await http.edit_voice_state(guild_id, self.id, voice_state_payload)
if voice_channel is not MISSING: if voice_channel is not MISSING:
payload['channel_id'] = voice_channel and voice_channel.id payload["channel_id"] = voice_channel and voice_channel.id
if roles is not MISSING: if roles is not MISSING:
payload['roles'] = tuple(r.id for r in roles) payload["roles"] = tuple(r.id for r in roles)
if payload: if payload:
data = await http.edit_member(guild_id, self.id, reason=reason, **payload) data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
@ -773,12 +784,12 @@ class Member(discord.abc.Messageable, _UserTag):
The operation failed. The operation failed.
""" """
payload = { payload = {
'channel_id': self.voice.channel.id, "channel_id": self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(), "request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(),
} }
if self._state.self_id != self.id: if self._state.self_id != self.id:
payload['suppress'] = False payload["suppress"] = False
await self._state.http.edit_voice_state(self.guild.id, self.id, payload) await self._state.http.edit_voice_state(self.guild.id, self.id, payload)
else: else:
await self._state.http.edit_my_voice_state(self.guild.id, payload) await self._state.http.edit_my_voice_state(self.guild.id, payload)

View File

@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
__all__ = ( __all__ = ("AllowedMentions",)
'AllowedMentions',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload from .types.message import AllowedMentions as AllowedMentionsPayload
@ -36,7 +34,7 @@ if TYPE_CHECKING:
class _FakeBool: class _FakeBool:
def __repr__(self): def __repr__(self):
return 'True' return "True"
def __eq__(self, other): def __eq__(self, other):
return other is True return other is True
@ -47,7 +45,7 @@ class _FakeBool:
default: Any = _FakeBool() default: Any = _FakeBool()
A = TypeVar('A', bound='AllowedMentions') A = TypeVar("A", bound="AllowedMentions")
class AllowedMentions: class AllowedMentions:
@ -80,7 +78,7 @@ class AllowedMentions:
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
__slots__ = ('everyone', 'users', 'roles', 'replied_user') __slots__ = ("everyone", "users", "roles", "replied_user")
def __init__( def __init__(
self, self,
@ -116,22 +114,22 @@ class AllowedMentions:
data = {} data = {}
if self.everyone: if self.everyone:
parse.append('everyone') parse.append("everyone")
if self.users == True: if self.users == True:
parse.append('users') parse.append("users")
elif self.users != False: elif self.users != False:
data['users'] = [x.id for x in self.users] data["users"] = [x.id for x in self.users]
if self.roles == True: if self.roles == True:
parse.append('roles') parse.append("roles")
elif self.roles != False: elif self.roles != False:
data['roles'] = [x.id for x in self.roles] data["roles"] = [x.id for x in self.roles]
if self.replied_user: if self.replied_user:
data['replied_user'] = True data["replied_user"] = True
data['parse'] = parse data["parse"] = parse
return data # type: ignore return data # type: ignore
def merge(self, other: AllowedMentions) -> AllowedMentions: def merge(self, other: AllowedMentions) -> AllowedMentions:
@ -146,6 +144,6 @@ class AllowedMentions:
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'{self.__class__.__name__}(everyone={self.everyone}, ' f"{self.__class__.__name__}(everyone={self.everyone}, "
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})' f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})"
) )

View File

@ -29,7 +29,21 @@ import datetime
import re import re
import io import io
from os import PathLike from os import PathLike
from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type from typing import (
Dict,
TYPE_CHECKING,
Union,
List,
Optional,
Any,
Callable,
Tuple,
ClassVar,
Optional,
overload,
TypeVar,
Type,
)
from . import utils from . import utils
from .reaction import Reaction from .reaction import Reaction
@ -76,15 +90,15 @@ if TYPE_CHECKING:
from .role import Role from .role import Role
from .ui.view import View from .ui.view import View
MR = TypeVar('MR', bound='MessageReference') MR = TypeVar("MR", bound="MessageReference")
EmojiInputType = Union[Emoji, PartialEmoji, str] EmojiInputType = Union[Emoji, PartialEmoji, str]
__all__ = ( __all__ = (
'Attachment', "Attachment",
'Message', "Message",
'PartialMessage', "PartialMessage",
'MessageReference', "MessageReference",
'DeletedReferencedMessage', "DeletedReferencedMessage",
) )
@ -93,15 +107,15 @@ def convert_emoji_reaction(emoji):
emoji = emoji.emoji emoji = emoji.emoji
if isinstance(emoji, Emoji): if isinstance(emoji, Emoji):
return f'{emoji.name}:{emoji.id}' return f"{emoji.name}:{emoji.id}"
if isinstance(emoji, PartialEmoji): if isinstance(emoji, PartialEmoji):
return emoji._as_reaction() return emoji._as_reaction()
if isinstance(emoji, str): if isinstance(emoji, str):
# Reactions can be in :name:id format, but not <:name:id>. # Reactions can be in :name:id format, but not <:name:id>.
# No existing emojis have <> in them, so this should be okay. # No existing emojis have <> in them, so this should be okay.
return emoji.strip('<>') return emoji.strip("<>")
raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.') raise InvalidArgument(f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.")
class Attachment(Hashable): class Attachment(Hashable):
@ -125,6 +139,10 @@ class Attachment(Hashable):
Returns the hash of the attachment. Returns the hash of the attachment.
.. describe:: int(x)
Returns the attachment's ID.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Attachment can now be casted to :class:`str` and is hashable. Attachment can now be casted to :class:`str` and is hashable.
@ -151,30 +169,36 @@ class Attachment(Hashable):
The attachment's `media type <https://en.wikipedia.org/wiki/Media_type>`_ The attachment's `media type <https://en.wikipedia.org/wiki/Media_type>`_
.. versionadded:: 1.7 .. versionadded:: 1.7
ephemeral: Optional[:class:`bool`]
If the attachment is ephemeral. Ephemeral attachments are temporary and
will automatically be removed after a set period of time.
.. versionadded:: 2.0
""" """
__slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') __slots__ = ("id", "size", "height", "width", "filename", "url", "proxy_url", "ephemeral", "_http", "content_type")
def __init__(self, *, data: AttachmentPayload, state: ConnectionState): def __init__(self, *, data: AttachmentPayload, state: ConnectionState):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.size: int = data['size'] self.size: int = data["size"]
self.height: Optional[int] = data.get('height') self.height: Optional[int] = data.get("height")
self.width: Optional[int] = data.get('width') self.width: Optional[int] = data.get("width")
self.filename: str = data['filename'] self.filename: str = data["filename"]
self.url: str = data.get('url') self.url: str = data.get("url")
self.proxy_url: str = data.get('proxy_url') self.proxy_url: str = data.get("proxy_url")
self._http = state.http self._http = state.http
self.content_type: Optional[str] = data.get('content_type') self.content_type: Optional[str] = data.get("content_type")
self.ephemeral: Optional[bool] = data.get("ephemeral")
def is_spoiler(self) -> bool: def is_spoiler(self) -> bool:
""":class:`bool`: Whether this attachment contains a spoiler.""" """:class:`bool`: Whether this attachment contains a spoiler."""
return self.filename.startswith('SPOILER_') return self.filename.startswith("SPOILER_")
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>' return f"<Attachment id={self.id} filename={self.filename!r} url={self.url!r}>"
def __str__(self) -> str: def __str__(self) -> str:
return self.url or '' return self.url or ""
async def save( async def save(
self, self,
@ -223,7 +247,7 @@ class Attachment(Hashable):
fp.seek(0) fp.seek(0)
return written return written
else: else:
with open(fp, 'wb') as f: with open(fp, "wb") as f:
return f.write(data) return f.write(data)
async def read(self, *, use_cached: bool = False) -> bytes: async def read(self, *, use_cached: bool = False) -> bytes:
@ -305,19 +329,19 @@ class Attachment(Hashable):
def to_dict(self) -> AttachmentPayload: def to_dict(self) -> AttachmentPayload:
result: AttachmentPayload = { result: AttachmentPayload = {
'filename': self.filename, "filename": self.filename,
'id': self.id, "id": self.id,
'proxy_url': self.proxy_url, "proxy_url": self.proxy_url,
'size': self.size, "size": self.size,
'url': self.url, "url": self.url,
'spoiler': self.is_spoiler(), "spoiler": self.is_spoiler(),
} }
if self.height: if self.height:
result['height'] = self.height result["height"] = self.height
if self.width: if self.width:
result['width'] = self.width result["width"] = self.width
if self.content_type: if self.content_type:
result['content_type'] = self.content_type result["content_type"] = self.content_type
return result return result
@ -331,7 +355,7 @@ class DeletedReferencedMessage:
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
__slots__ = ('_parent',) __slots__ = ("_parent",)
def __init__(self, parent: MessageReference): def __init__(self, parent: MessageReference):
self._parent: MessageReference = parent self._parent: MessageReference = parent
@ -343,7 +367,7 @@ class DeletedReferencedMessage:
def id(self) -> int: def id(self) -> int:
""":class:`int`: The message ID of the deleted referenced message.""" """:class:`int`: The message ID of the deleted referenced message."""
# the parent's message id won't be None here # the parent's message id won't be None here
return self._parent.message_id # type: ignore return self._parent.message_id # type: ignore
@property @property
def channel_id(self) -> int: def channel_id(self) -> int:
@ -390,9 +414,11 @@ class MessageReference:
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
__slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') __slots__ = ("message_id", "channel_id", "guild_id", "fail_if_not_exists", "resolved", "_state")
def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): def __init__(
self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True
):
self._state: Optional[ConnectionState] = None self._state: Optional[ConnectionState] = None
self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None
self.message_id: Optional[int] = message_id self.message_id: Optional[int] = message_id
@ -403,10 +429,10 @@ class MessageReference:
@classmethod @classmethod
def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR:
self = cls.__new__(cls) self = cls.__new__(cls)
self.message_id = utils._get_as_snowflake(data, 'message_id') self.message_id = utils._get_as_snowflake(data, "message_id")
self.channel_id = int(data.pop('channel_id')) self.channel_id = int(data.pop("channel_id"))
self.guild_id = utils._get_as_snowflake(data, 'guild_id') self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.fail_if_not_exists = data.get('fail_if_not_exists', True) self.fail_if_not_exists = data.get("fail_if_not_exists", True)
self._state = state self._state = state
self.resolved = None self.resolved = None
return self return self
@ -435,7 +461,7 @@ class MessageReference:
self = cls( self = cls(
message_id=message.id, message_id=message.id,
channel_id=message.channel.id, channel_id=message.channel.id,
guild_id=getattr(message.guild, 'id', None), guild_id=getattr(message.guild, "id", None),
fail_if_not_exists=fail_if_not_exists, fail_if_not_exists=fail_if_not_exists,
) )
self._state = message._state self._state = message._state
@ -452,36 +478,36 @@ class MessageReference:
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
guild_id = self.guild_id if self.guild_id is not None else '@me' guild_id = self.guild_id if self.guild_id is not None else "@me"
return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}' return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>' return f"<MessageReference message_id={self.message_id!r} channel_id={self.channel_id!r} guild_id={self.guild_id!r}>"
def to_dict(self) -> MessageReferencePayload: def to_dict(self) -> MessageReferencePayload:
result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} result: MessageReferencePayload = {"message_id": self.message_id} if self.message_id is not None else {}
result['channel_id'] = self.channel_id result["channel_id"] = self.channel_id
if self.guild_id is not None: if self.guild_id is not None:
result['guild_id'] = self.guild_id result["guild_id"] = self.guild_id
if self.fail_if_not_exists is not None: if self.fail_if_not_exists is not None:
result['fail_if_not_exists'] = self.fail_if_not_exists result["fail_if_not_exists"] = self.fail_if_not_exists
return result return result
to_message_reference_dict = to_dict to_message_reference_dict = to_dict
def flatten_handlers(cls): def flatten_handlers(cls):
prefix = len('_handle_') prefix = len("_handle_")
handlers = [ handlers = [
(key[prefix:], value) (key[prefix:], value)
for key, value in cls.__dict__.items() for key, value in cls.__dict__.items()
if key.startswith('_handle_') and key != '_handle_member' if key.startswith("_handle_") and key != "_handle_member"
] ]
# store _handle_member last # store _handle_member last
handlers.append(('member', cls._handle_member)) handlers.append(("member", cls._handle_member))
cls._HANDLERS = handlers cls._HANDLERS = handlers
cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')] cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")]
return cls return cls
@ -503,6 +529,14 @@ class Message(Hashable):
Returns the message's hash. Returns the message's hash.
.. describe:: str(x)
Returns the message's content.
.. describe:: int(x)
Returns the message's ID.
Attributes Attributes
----------- -----------
tts: :class:`bool` tts: :class:`bool`
@ -603,36 +637,36 @@ class Message(Hashable):
""" """
__slots__ = ( __slots__ = (
'_state', "_state",
'_edited_timestamp', "_edited_timestamp",
'_cs_channel_mentions', "_cs_channel_mentions",
'_cs_raw_mentions', "_cs_raw_mentions",
'_cs_clean_content', "_cs_clean_content",
'_cs_raw_channel_mentions', "_cs_raw_channel_mentions",
'_cs_raw_role_mentions', "_cs_raw_role_mentions",
'_cs_system_content', "_cs_system_content",
'tts', "tts",
'content', "content",
'channel', "channel",
'webhook_id', "webhook_id",
'mention_everyone', "mention_everyone",
'embeds', "embeds",
'id', "id",
'mentions', "mentions",
'author', "author",
'attachments', "attachments",
'nonce', "nonce",
'pinned', "pinned",
'role_mentions', "role_mentions",
'type', "type",
'flags', "flags",
'reactions', "reactions",
'reference', "reference",
'application', "application",
'activity', "activity",
'stickers', "stickers",
'components', "components",
'guild', "guild",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -652,39 +686,39 @@ class Message(Hashable):
data: MessagePayload, data: MessagePayload,
): ):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id")
self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])] self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])]
self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']] self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]]
self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]]
self.application: Optional[MessageApplicationPayload] = data.get('application') self.application: Optional[MessageApplicationPayload] = data.get("application")
self.activity: Optional[MessageActivityPayload] = data.get('activity') self.activity: Optional[MessageActivityPayload] = data.get("activity")
self.channel: MessageableChannel = channel self.channel: MessageableChannel = channel
self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data["edited_timestamp"])
self.type: MessageType = try_enum(MessageType, data['type']) self.type: MessageType = try_enum(MessageType, data["type"])
self.pinned: bool = data['pinned'] self.pinned: bool = data["pinned"]
self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0)) self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0))
self.mention_everyone: bool = data['mention_everyone'] self.mention_everyone: bool = data["mention_everyone"]
self.tts: bool = data['tts'] self.tts: bool = data["tts"]
self.content: str = data['content'] self.content: str = data["content"]
self.nonce: Optional[Union[int, str]] = data.get('nonce') self.nonce: Optional[Union[int, str]] = data.get("nonce")
self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])]
self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] self.components: List[Component] = [_component_factory(d) for d in data.get("components", [])]
try: try:
# if the channel doesn't have a guild attribute, we handle that # if the channel doesn't have a guild attribute, we handle that
self.guild = channel.guild # type: ignore self.guild = channel.guild # type: ignore
except AttributeError: except AttributeError:
self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id"))
try: try:
ref = data['message_reference'] ref = data["message_reference"]
except KeyError: except KeyError:
self.reference = None self.reference = None
else: else:
self.reference = ref = MessageReference.with_state(state, ref) self.reference = ref = MessageReference.with_state(state, ref)
try: try:
resolved = data['referenced_message'] resolved = data["referenced_message"]
except KeyError: except KeyError:
pass pass
else: else:
@ -700,17 +734,18 @@ class Message(Hashable):
# the channel will be the correct type here # the channel will be the correct type here
ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore
for handler in ('author', 'member', 'mentions', 'mention_roles'): for handler in ("author", "member", "mentions", "mention_roles"):
try: try:
getattr(self, f'_handle_{handler}')(data[handler]) getattr(self, f"_handle_{handler}")(data[handler])
except KeyError: except KeyError:
continue continue
def __repr__(self) -> str: def __repr__(self) -> str:
name = self.__class__.__name__ name = self.__class__.__name__
return ( return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>"
f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>'
) def __str__(self) -> Optional[str]:
return self.content
def _try_patch(self, data, key, transform=None) -> None: def _try_patch(self, data, key, transform=None) -> None:
try: try:
@ -725,7 +760,7 @@ class Message(Hashable):
def _add_reaction(self, data, emoji, user_id) -> Reaction: def _add_reaction(self, data, emoji, user_id) -> Reaction:
reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) reaction = utils.find(lambda r: r.emoji == emoji, self.reactions)
is_me = data['me'] = user_id == self._state.self_id is_me = data["me"] = user_id == self._state.self_id
if reaction is None: if reaction is None:
reaction = Reaction(message=self, data=data, emoji=emoji) reaction = Reaction(message=self, data=data, emoji=emoji)
@ -742,7 +777,7 @@ class Message(Hashable):
if reaction is None: if reaction is None:
# already removed? # already removed?
raise ValueError('Emoji already removed?') raise ValueError("Emoji already removed?")
# if reaction isn't in the list, we crash. This means discord # if reaction isn't in the list, we crash. This means discord
# sent bad data, or we stored improperly # sent bad data, or we stored improperly
@ -856,7 +891,7 @@ class Message(Hashable):
return return
for mention in filter(None, mentions): for mention in filter(None, mentions):
id_search = int(mention['id']) id_search = int(mention["id"])
member = guild.get_member(id_search) member = guild.get_member(id_search)
if member is not None: if member is not None:
r.append(member) r.append(member)
@ -878,7 +913,7 @@ class Message(Hashable):
self.guild = new_guild self.guild = new_guild
self.channel = new_channel self.channel = new_channel
@utils.cached_slot_property('_cs_raw_mentions') @utils.cached_slot_property("_cs_raw_mentions")
def raw_mentions(self) -> List[int]: def raw_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of user IDs matched with """List[:class:`int`]: A property that returns an array of user IDs matched with
the syntax of ``<@user_id>`` in the message content. the syntax of ``<@user_id>`` in the message content.
@ -886,30 +921,30 @@ class Message(Hashable):
This allows you to receive the user IDs of mentioned users This allows you to receive the user IDs of mentioned users
even in a private message context. even in a private message context.
""" """
return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)] return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_channel_mentions') @utils.cached_slot_property("_cs_raw_channel_mentions")
def raw_channel_mentions(self) -> List[int]: def raw_channel_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of channel IDs matched with """List[:class:`int`]: A property that returns an array of channel IDs matched with
the syntax of ``<#channel_id>`` in the message content. the syntax of ``<#channel_id>`` in the message content.
""" """
return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)] return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_raw_role_mentions') @utils.cached_slot_property("_cs_raw_role_mentions")
def raw_role_mentions(self) -> List[int]: def raw_role_mentions(self) -> List[int]:
"""List[:class:`int`]: A property that returns an array of role IDs matched with """List[:class:`int`]: A property that returns an array of role IDs matched with
the syntax of ``<@&role_id>`` in the message content. the syntax of ``<@&role_id>`` in the message content.
""" """
return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)] return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)]
@utils.cached_slot_property('_cs_channel_mentions') @utils.cached_slot_property("_cs_channel_mentions")
def channel_mentions(self) -> List[GuildChannel]: def channel_mentions(self) -> List[GuildChannel]:
if self.guild is None: if self.guild is None:
return [] return []
it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions))
return utils._unique(it) return utils._unique(it)
@utils.cached_slot_property('_cs_clean_content') @utils.cached_slot_property("_cs_clean_content")
def clean_content(self) -> str: def clean_content(self) -> str:
""":class:`str`: A property that returns the content in a "cleaned up" """:class:`str`: A property that returns the content in a "cleaned up"
manner. This basically means that mentions are transformed manner. This basically means that mentions are transformed
@ -956,9 +991,9 @@ class Message(Hashable):
# fmt: on # fmt: on
def repl(obj): def repl(obj):
return transformations.get(re.escape(obj.group(0)), '') return transformations.get(re.escape(obj.group(0)), "")
pattern = re.compile('|'.join(transformations.keys())) pattern = re.compile("|".join(transformations.keys()))
result = pattern.sub(repl, self.content) result = pattern.sub(repl, self.content)
return escape_mentions(result) return escape_mentions(result)
@ -975,8 +1010,8 @@ class Message(Hashable):
@property @property
def jump_url(self) -> str: def jump_url(self) -> str:
""":class:`str`: Returns a URL that allows the client to jump to this message.""" """:class:`str`: Returns a URL that allows the client to jump to this message."""
guild_id = getattr(self.guild, 'id', '@me') guild_id = getattr(self.guild, "id", "@me")
return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}' return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}"
def is_system(self) -> bool: def is_system(self) -> bool:
""":class:`bool`: Whether the message is a system message. """:class:`bool`: Whether the message is a system message.
@ -993,7 +1028,7 @@ class Message(Hashable):
MessageType.thread_starter_message, MessageType.thread_starter_message,
) )
@utils.cached_slot_property('_cs_system_content') @utils.cached_slot_property("_cs_system_content")
def system_content(self): def system_content(self):
r""":class:`str`: A property that returns the content that is rendered r""":class:`str`: A property that returns the content that is rendered
regardless of the :attr:`Message.type`. regardless of the :attr:`Message.type`.
@ -1008,24 +1043,24 @@ class Message(Hashable):
if self.type is MessageType.recipient_add: if self.type is MessageType.recipient_add:
if self.channel.type is ChannelType.group: if self.channel.type is ChannelType.group:
return f'{self.author.name} added {self.mentions[0].name} to the group.' return f"{self.author.name} added {self.mentions[0].name} to the group."
else: else:
return f'{self.author.name} added {self.mentions[0].name} to the thread.' return f"{self.author.name} added {self.mentions[0].name} to the thread."
if self.type is MessageType.recipient_remove: if self.type is MessageType.recipient_remove:
if self.channel.type is ChannelType.group: if self.channel.type is ChannelType.group:
return f'{self.author.name} removed {self.mentions[0].name} from the group.' return f"{self.author.name} removed {self.mentions[0].name} from the group."
else: else:
return f'{self.author.name} removed {self.mentions[0].name} from the thread.' return f"{self.author.name} removed {self.mentions[0].name} from the thread."
if self.type is MessageType.channel_name_change: if self.type is MessageType.channel_name_change:
return f'{self.author.name} changed the channel name: **{self.content}**' return f"{self.author.name} changed the channel name: **{self.content}**"
if self.type is MessageType.channel_icon_change: if self.type is MessageType.channel_icon_change:
return f'{self.author.name} changed the channel icon.' return f"{self.author.name} changed the channel icon."
if self.type is MessageType.pins_add: if self.type is MessageType.pins_add:
return f'{self.author.name} pinned a message to this channel.' return f"{self.author.name} pinned a message to this channel."
if self.type is MessageType.new_member: if self.type is MessageType.new_member:
formats = [ formats = [
@ -1049,64 +1084,64 @@ class Message(Hashable):
if self.type is MessageType.premium_guild_subscription: if self.type is MessageType.premium_guild_subscription:
if not self.content: if not self.content:
return f'{self.author.name} just boosted the server!' return f"{self.author.name} just boosted the server!"
else: else:
return f'{self.author.name} just boosted the server **{self.content}** times!' return f"{self.author.name} just boosted the server **{self.content}** times!"
if self.type is MessageType.premium_guild_tier_1: if self.type is MessageType.premium_guild_tier_1:
if not self.content: if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**' return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**"
else: else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**' return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**"
if self.type is MessageType.premium_guild_tier_2: if self.type is MessageType.premium_guild_tier_2:
if not self.content: if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**' return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**"
else: else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**' return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**"
if self.type is MessageType.premium_guild_tier_3: if self.type is MessageType.premium_guild_tier_3:
if not self.content: if not self.content:
return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**' return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**"
else: else:
return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**' return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**"
if self.type is MessageType.channel_follow_add: if self.type is MessageType.channel_follow_add:
return f'{self.author.name} has added {self.content} to this channel' return f"{self.author.name} has added {self.content} to this channel"
if self.type is MessageType.guild_stream: if self.type is MessageType.guild_stream:
# the author will be a Member # the author will be a Member
return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore
if self.type is MessageType.guild_discovery_disqualified: if self.type is MessageType.guild_discovery_disqualified:
return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.' return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details."
if self.type is MessageType.guild_discovery_requalified: if self.type is MessageType.guild_discovery_requalified:
return 'This server is eligible for Server Discovery again and has been automatically relisted!' return "This server is eligible for Server Discovery again and has been automatically relisted!"
if self.type is MessageType.guild_discovery_grace_period_initial_warning: if self.type is MessageType.guild_discovery_grace_period_initial_warning:
return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.' return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery."
if self.type is MessageType.guild_discovery_grace_period_final_warning: if self.type is MessageType.guild_discovery_grace_period_final_warning:
return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.' return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery."
if self.type is MessageType.thread_created: if self.type is MessageType.thread_created:
return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.' return f"{self.author.name} started a thread: **{self.content}**. See all **threads**."
if self.type is MessageType.reply: if self.type is MessageType.reply:
return self.content return self.content
if self.type is MessageType.thread_starter_message: if self.type is MessageType.thread_starter_message:
if self.reference is None or self.reference.resolved is None: if self.reference is None or self.reference.resolved is None:
return 'Sorry, we couldn\'t load the first message in this thread' return "Sorry, we couldn't load the first message in this thread"
# the resolved message for the reference will be a Message # the resolved message for the reference will be a Message
return self.reference.resolved.content # type: ignore return self.reference.resolved.content # type: ignore
if self.type is MessageType.guild_invite_reminder: if self.type is MessageType.guild_invite_reminder:
return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!' return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!"
async def delete(self, *, delay: Optional[float] = None) -> None: async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None:
"""|coro| """|coro|
Deletes the message. Deletes the message.
@ -1117,12 +1152,17 @@ class Message(Hashable):
.. versionchanged:: 1.1 .. versionchanged:: 1.1
Added the new ``delay`` keyword-only parameter. Added the new ``delay`` keyword-only parameter.
.. versionchanged:: 2.0
Added the new ``silent`` keyword-only parameter.
Parameters Parameters
----------- -----------
delay: Optional[:class:`float`] delay: Optional[:class:`float`]
If provided, the number of seconds to wait in the background If provided, the number of seconds to wait in the background
before deleting the message. If the deletion fails then it is silently ignored. before deleting the message. If the deletion fails then it is silently ignored.
silent: :class:`bool`
If silent is set to ``True``, the error will not be raised, it will be ignored.
This defaults to ``False``
Raises Raises
------ ------
@ -1144,7 +1184,11 @@ class Message(Hashable):
asyncio.create_task(delete(delay)) asyncio.create_task(delete(delay))
else: else:
await self._state.http.delete_message(self.channel.id, self.id) try:
await self._state.http.delete_message(self.channel.id, self.id)
except Exception:
if not silent:
raise
@overload @overload
async def edit( async def edit(
@ -1246,45 +1290,45 @@ class Message(Hashable):
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if content is not MISSING: if content is not MISSING:
if content is not None: if content is not None:
payload['content'] = str(content) payload["content"] = str(content)
else: else:
payload['content'] = None payload["content"] = None
if embed is not MISSING and embeds is not MISSING: if embed is not MISSING and embeds is not MISSING:
raise InvalidArgument('cannot pass both embed and embeds parameter to edit()') raise InvalidArgument("cannot pass both embed and embeds parameter to edit()")
if embed is not MISSING: if embed is not MISSING:
if embed is None: if embed is None:
payload['embeds'] = [] payload["embeds"] = []
else: else:
payload['embeds'] = [embed.to_dict()] payload["embeds"] = [embed.to_dict()]
elif embeds is not MISSING: elif embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds] payload["embeds"] = [e.to_dict() for e in embeds]
if suppress is not MISSING: if suppress is not MISSING:
flags = MessageFlags._from_value(self.flags.value) flags = MessageFlags._from_value(self.flags.value)
flags.suppress_embeds = suppress flags.suppress_embeds = suppress
payload['flags'] = flags.value payload["flags"] = flags.value
if allowed_mentions is MISSING: if allowed_mentions is MISSING:
if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: if self._state.allowed_mentions is not None and self.author.id == self._state.self_id:
payload['allowed_mentions'] = self._state.allowed_mentions.to_dict() payload["allowed_mentions"] = self._state.allowed_mentions.to_dict()
else: else:
if allowed_mentions is not None: if allowed_mentions is not None:
if self._state.allowed_mentions is not None: if self._state.allowed_mentions is not None:
payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() payload["allowed_mentions"] = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
else: else:
payload['allowed_mentions'] = allowed_mentions.to_dict() payload["allowed_mentions"] = allowed_mentions.to_dict()
if attachments is not MISSING: if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments] payload["attachments"] = [a.to_dict() for a in attachments]
if view is not MISSING: if view is not MISSING:
self._state.prevent_view_updates_for(self.id) self._state.prevent_view_updates_for(self.id)
if view: if view:
payload['components'] = view.to_components() payload["components"] = view.to_components()
else: else:
payload['components'] = [] payload["components"] = []
data = await self._state.http.edit_message(self.channel.id, self.id, **payload) data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data) message = Message(state=self._state, channel=self.channel, data=data)
@ -1526,9 +1570,11 @@ class Message(Hashable):
The created thread. The created thread.
""" """
if self.guild is None: if self.guild is None:
raise InvalidArgument('This message does not have guild info attached.') raise InvalidArgument("This message does not have guild info attached.")
default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440) default_auto_archive_duration: ThreadArchiveDuration = getattr(
self.channel, "default_auto_archive_duration", 1440
)
data = await self._state.http.start_thread_with_message( data = await self._state.http.start_thread_with_message(
self.channel.id, self.channel.id,
self.id, self.id,
@ -1586,12 +1632,12 @@ class Message(Hashable):
def to_message_reference_dict(self) -> MessageReferencePayload: def to_message_reference_dict(self) -> MessageReferencePayload:
data: MessageReferencePayload = { data: MessageReferencePayload = {
'message_id': self.id, "message_id": self.id,
'channel_id': self.channel.id, "channel_id": self.channel.id,
} }
if self.guild is not None: if self.guild is not None:
data['guild_id'] = self.guild.id data["guild_id"] = self.guild.id
return data return data
@ -1625,6 +1671,10 @@ class PartialMessage(Hashable):
Returns the partial message's hash. Returns the partial message's hash.
.. describe:: int(x)
Returns the partial message's ID.
Attributes Attributes
----------- -----------
channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`] channel: Union[:class:`TextChannel`, :class:`Thread`, :class:`DMChannel`]
@ -1633,7 +1683,7 @@ class PartialMessage(Hashable):
The message ID. The message ID.
""" """
__slots__ = ('channel', 'id', '_cs_guild', '_state') __slots__ = ("channel", "id", "_cs_guild", "_state")
jump_url: str = Message.jump_url # type: ignore jump_url: str = Message.jump_url # type: ignore
delete = Message.delete delete = Message.delete
@ -1657,7 +1707,7 @@ class PartialMessage(Hashable):
ChannelType.public_thread, ChannelType.public_thread,
ChannelType.private_thread, ChannelType.private_thread,
): ):
raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}') raise TypeError(f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}")
self.channel: PartialMessageableChannel = channel self.channel: PartialMessageableChannel = channel
self._state: ConnectionState = channel._state self._state: ConnectionState = channel._state
@ -1673,17 +1723,17 @@ class PartialMessage(Hashable):
pinned = property(None, lambda x, y: None) pinned = property(None, lambda x, y: None)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<PartialMessage id={self.id} channel={self.channel!r}>' return f"<PartialMessage id={self.id} channel={self.channel!r}>"
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: The partial message's creation time in UTC.""" """:class:`datetime.datetime`: The partial message's creation time in UTC."""
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
@utils.cached_slot_property('_cs_guild') @utils.cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]: def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable.""" """Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable."""
return getattr(self.channel, 'guild', None) return getattr(self.channel, "guild", None)
async def fetch(self) -> Message: async def fetch(self) -> Message:
"""|coro| """|coro|
@ -1765,34 +1815,34 @@ class PartialMessage(Hashable):
""" """
try: try:
content = fields['content'] content = fields["content"]
except KeyError: except KeyError:
pass pass
else: else:
if content is not None: if content is not None:
fields['content'] = str(content) fields["content"] = str(content)
try: try:
embed = fields['embed'] embed = fields["embed"]
except KeyError: except KeyError:
pass pass
else: else:
if embed is not None: if embed is not None:
fields['embed'] = embed.to_dict() fields["embed"] = embed.to_dict()
try: try:
suppress: bool = fields.pop('suppress') suppress: bool = fields.pop("suppress")
except KeyError: except KeyError:
pass pass
else: else:
flags = MessageFlags._from_value(0) flags = MessageFlags._from_value(0)
flags.suppress_embeds = suppress flags.suppress_embeds = suppress
fields['flags'] = flags.value fields["flags"] = flags.value
delete_after = fields.pop('delete_after', None) delete_after = fields.pop("delete_after", None)
try: try:
allowed_mentions = fields.pop('allowed_mentions') allowed_mentions = fields.pop("allowed_mentions")
except KeyError: except KeyError:
pass pass
else: else:
@ -1801,19 +1851,19 @@ class PartialMessage(Hashable):
allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict() allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict()
else: else:
allowed_mentions = allowed_mentions.to_dict() allowed_mentions = allowed_mentions.to_dict()
fields['allowed_mentions'] = allowed_mentions fields["allowed_mentions"] = allowed_mentions
try: try:
view = fields.pop('view') view = fields.pop("view")
except KeyError: except KeyError:
# To check for the view afterwards # To check for the view afterwards
view = None view = None
else: else:
self._state.prevent_view_updates_for(self.id) self._state.prevent_view_updates_for(self.id)
if view: if view:
fields['components'] = view.to_components() fields["components"] = view.to_components()
else: else:
fields['components'] = [] fields["components"] = []
if fields: if fields:
data = await self._state.http.edit_message(self.channel.id, self.id, **fields) data = await self._state.http.edit_message(self.channel.id, self.id, **fields)

View File

@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE.
""" """
__all__ = ( __all__ = (
'EqualityComparable', "EqualityComparable",
'Hashable', "Hashable",
) )
class EqualityComparable: class EqualityComparable:
__slots__ = () __slots__ = ()
@ -40,8 +41,12 @@ class EqualityComparable:
return other.id != self.id return other.id != self.id
return True return True
class Hashable(EqualityComparable): class Hashable(EqualityComparable):
__slots__ = () __slots__ = ()
def __int__(self) -> int:
return self.id
def __hash__(self) -> int: def __hash__(self) -> int:
return self.id >> 22 return self.id >> 22

View File

@ -35,11 +35,11 @@ from typing import (
if TYPE_CHECKING: if TYPE_CHECKING:
import datetime import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
__all__ = ( __all__ = ("Object",)
'Object',
)
class Object(Hashable): class Object(Hashable):
"""Represents a generic Discord object. """Represents a generic Discord object.
@ -69,6 +69,10 @@ class Object(Hashable):
Returns the object's hash. Returns the object's hash.
.. describe:: int(x)
Returns the object's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`
@ -79,12 +83,12 @@ class Object(Hashable):
try: try:
id = int(id) id = int(id)
except ValueError: except ValueError:
raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None raise TypeError(f"id parameter must be convertable to int not {id.__class__!r}") from None
else: else:
self.id = id self.id = id
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Object id={self.id!r}>' return f"<Object id={self.id!r}>"
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:

View File

@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException from .errors import DiscordException
__all__ = ( __all__ = (
'OggError', "OggError",
'OggPage', "OggPage",
'OggStream', "OggStream",
) )
class OggError(DiscordException): class OggError(DiscordException):
"""An exception that is thrown for Ogg stream parsing errors.""" """An exception that is thrown for Ogg stream parsing errors."""
pass pass
# https://tools.ietf.org/html/rfc3533 # https://tools.ietf.org/html/rfc3533
# https://tools.ietf.org/html/rfc7845 # https://tools.ietf.org/html/rfc7845
class OggPage: class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB') _header: ClassVar[struct.Struct] = struct.Struct("<xBQIIIB")
if TYPE_CHECKING: if TYPE_CHECKING:
flag: int flag: int
gran_pos: int gran_pos: int
@ -57,14 +61,13 @@ class OggPage:
try: try:
header = stream.read(struct.calcsize(self._header.format)) header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \ self.flag, self.gran_pos, self.serial, self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum) self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable)) bodylen = sum(struct.unpack("B" * self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen) self.data: bytes = stream.read(bodylen)
except Exception: except Exception:
raise OggError('bad data stream') from None raise OggError("bad data stream") from None
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]: def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
packetlen = offset = 0 packetlen = offset = 0
@ -76,7 +79,7 @@ class OggPage:
partial = True partial = True
else: else:
packetlen += seg packetlen += seg
yield self.data[offset:offset+packetlen], True yield self.data[offset : offset + packetlen], True
offset += packetlen offset += packetlen
packetlen = 0 packetlen = 0
partial = False partial = False
@ -84,18 +87,19 @@ class OggPage:
if partial: if partial:
yield self.data[offset:], False yield self.data[offset:], False
class OggStream: class OggStream:
def __init__(self, stream: IO[bytes]) -> None: def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream self.stream: IO[bytes] = stream
def _next_page(self) -> Optional[OggPage]: def _next_page(self) -> Optional[OggPage]:
head = self.stream.read(4) head = self.stream.read(4)
if head == b'OggS': if head == b"OggS":
return OggPage(self.stream) return OggPage(self.stream)
elif not head: elif not head:
return None return None
else: else:
raise OggError('invalid header magic') raise OggError("invalid header magic")
def _iter_pages(self) -> Generator[OggPage, None, None]: def _iter_pages(self) -> Generator[OggPage, None, None]:
page = self._next_page() page = self._next_page()
@ -104,10 +108,10 @@ class OggStream:
page = self._next_page() page = self._next_page()
def iter_packets(self) -> Generator[bytes, None, None]: def iter_packets(self) -> Generator[bytes, None, None]:
partial = b'' partial = b""
for page in self._iter_pages(): for page in self._iter_pages():
for data, complete in page.iter_packets(): for data, complete in page.iter_packets():
partial += data partial += data
if complete: if complete:
yield partial yield partial
partial = b'' partial = b""

View File

@ -38,9 +38,10 @@ import sys
from .errors import DiscordException, InvalidArgument from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING: if TYPE_CHECKING:
T = TypeVar('T') T = TypeVar("T")
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"]
SIGNAL_CTL = Literal['auto', 'voice', 'music'] SIGNAL_CTL = Literal["auto", "voice", "music"]
class BandCtl(TypedDict): class BandCtl(TypedDict):
narrow: int narrow: int
@ -49,81 +50,89 @@ class BandCtl(TypedDict):
superwide: int superwide: int
full: int full: int
class SignalCtl(TypedDict): class SignalCtl(TypedDict):
auto: int auto: int
voice: int voice: int
music: int music: int
__all__ = ( __all__ = (
'Encoder', "Encoder",
'OpusError', "OpusError",
'OpusNotLoaded', "OpusNotLoaded",
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float) c_float_ptr = ctypes.POINTER(ctypes.c_float)
_lib = None _lib = None
class EncoderStruct(ctypes.Structure): class EncoderStruct(ctypes.Structure):
pass pass
class DecoderStruct(ctypes.Structure): class DecoderStruct(ctypes.Structure):
pass pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct) EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct)
## Some constants from opus_defines.h ## Some constants from opus_defines.h
# Error codes # Error codes
OK = 0 OK = 0
BAD_ARG = -1 BAD_ARG = -1
# Encoder CTLs # Encoder CTLs
APPLICATION_AUDIO = 2049 APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048 APPLICATION_VOIP = 2048
APPLICATION_LOWDELAY = 2051 APPLICATION_LOWDELAY = 2051
CTL_SET_BITRATE = 4002 CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008 CTL_SET_BANDWIDTH = 4008
CTL_SET_FEC = 4012 CTL_SET_FEC = 4012
CTL_SET_PLP = 4014 CTL_SET_PLP = 4014
CTL_SET_SIGNAL = 4024 CTL_SET_SIGNAL = 4024
# Decoder CTLs # Decoder CTLs
CTL_SET_GAIN = 4034 CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039 CTL_LAST_PACKET_DURATION = 4039
band_ctl: BandCtl = { band_ctl: BandCtl = {
'narrow': 1101, "narrow": 1101,
'medium': 1102, "medium": 1102,
'wide': 1103, "wide": 1103,
'superwide': 1104, "superwide": 1104,
'full': 1105, "full": 1105,
} }
signal_ctl: SignalCtl = { signal_ctl: SignalCtl = {
'auto': -1000, "auto": -1000,
'voice': 3001, "voice": 3001,
'music': 3002, "music": 3002,
} }
def _err_lt(result: int, func: Callable, args: List) -> int: def _err_lt(result: int, func: Callable, args: List) -> int:
if result < OK: if result < OK:
_log.info('error has happened in %s', func.__name__) _log.info("error has happened in %s", func.__name__)
raise OpusError(result) raise OpusError(result)
return result return result
def _err_ne(result: T, func: Callable, args: List) -> T: def _err_ne(result: T, func: Callable, args: List) -> T:
ret = args[-1]._obj ret = args[-1]._obj
if ret.value != OK: if ret.value != OK:
_log.info('error has happened in %s', func.__name__) _log.info("error has happened in %s", func.__name__)
raise OpusError(ret.value) raise OpusError(ret.value)
return result return result
# A list of exported functions. # A list of exported functions.
# The first argument is obviously the name. # The first argument is obviously the name.
# The second one are the types of arguments it takes. # The second one are the types of arguments it takes.
@ -131,54 +140,51 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
# The fourth is the error handler. # The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [ exported_functions: List[Tuple[Any, ...]] = [
# Generic # Generic
('opus_get_version_string', ("opus_get_version_string", None, ctypes.c_char_p, None),
None, ctypes.c_char_p, None), ("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None),
('opus_strerror',
[ctypes.c_int], ctypes.c_char_p, None),
# Encoder functions # Encoder functions
('opus_encoder_get_size', ("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None),
[ctypes.c_int], ctypes.c_int, None), ("opus_encoder_create", [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
('opus_encoder_create', (
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), "opus_encode",
('opus_encode', [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), ctypes.c_int32,
('opus_encode_float', _err_lt,
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), ),
('opus_encoder_ctl', (
None, ctypes.c_int32, _err_lt), "opus_encode_float",
('opus_encoder_destroy', [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
[EncoderStructPtr], None, None), ctypes.c_int32,
_err_lt,
),
("opus_encoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_encoder_destroy", [EncoderStructPtr], None, None),
# Decoder functions # Decoder functions
('opus_decoder_get_size', ("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None),
[ctypes.c_int], ctypes.c_int, None), ("opus_decoder_create", [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
('opus_decoder_create', (
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), "opus_decode",
('opus_decode',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt), ctypes.c_int,
('opus_decode_float', _err_lt,
),
(
"opus_decode_float",
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int, _err_lt), ctypes.c_int,
('opus_decoder_ctl', _err_lt,
None, ctypes.c_int32, _err_lt), ),
('opus_decoder_destroy', ("opus_decoder_ctl", None, ctypes.c_int32, _err_lt),
[DecoderStructPtr], None, None), ("opus_decoder_destroy", [DecoderStructPtr], None, None),
('opus_decoder_get_nb_samples', ("opus_decoder_get_nb_samples", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
# Packet functions # Packet functions
('opus_packet_get_bandwidth', ("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt),
[ctypes.c_char_p], ctypes.c_int, _err_lt), ("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_channels', ("opus_packet_get_nb_frames", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
[ctypes.c_char_p], ctypes.c_int, _err_lt), ("opus_packet_get_samples_per_frame", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_nb_frames',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_samples_per_frame',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
] ]
def libopus_loader(name: str) -> Any: def libopus_loader(name: str) -> Any:
# create the library... # create the library...
lib = ctypes.cdll.LoadLibrary(name) lib = ctypes.cdll.LoadLibrary(name)
@ -203,22 +209,24 @@ def libopus_loader(name: str) -> Any:
return lib return lib
def _load_default() -> bool: def _load_default() -> bool:
global _lib global _lib
try: try:
if sys.platform == 'win32': if sys.platform == "win32":
_basedir = os.path.dirname(os.path.abspath(__file__)) _basedir = os.path.dirname(os.path.abspath(__file__))
_bitness = struct.calcsize('P') * 8 _bitness = struct.calcsize("P") * 8
_target = 'x64' if _bitness > 32 else 'x86' _target = "x64" if _bitness > 32 else "x86"
_filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll') _filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll")
_lib = libopus_loader(_filename) _lib = libopus_loader(_filename)
else: else:
_lib = libopus_loader(ctypes.util.find_library('opus')) _lib = libopus_loader(ctypes.util.find_library("opus"))
except Exception: except Exception:
_lib = None _lib = None
return _lib is not None return _lib is not None
def load_opus(name: str) -> None: def load_opus(name: str) -> None:
"""Loads the libopus shared library for use with voice. """Loads the libopus shared library for use with voice.
@ -257,6 +265,7 @@ def load_opus(name: str) -> None:
global _lib global _lib
_lib = libopus_loader(name) _lib = libopus_loader(name)
def is_loaded() -> bool: def is_loaded() -> bool:
"""Function to check if opus lib is successfully loaded either """Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`. via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -271,6 +280,7 @@ def is_loaded() -> bool:
global _lib global _lib
return _lib is not None return _lib is not None
class OpusError(DiscordException): class OpusError(DiscordException):
"""An exception that is thrown for libopus related errors. """An exception that is thrown for libopus related errors.
@ -282,19 +292,22 @@ class OpusError(DiscordException):
def __init__(self, code: int): def __init__(self, code: int):
self.code: int = code self.code: int = code
msg = _lib.opus_strerror(self.code).decode('utf-8') msg = _lib.opus_strerror(self.code).decode("utf-8")
_log.info('"%s" has happened', msg) _log.info('"%s" has happened', msg)
super().__init__(msg) super().__init__(msg)
class OpusNotLoaded(DiscordException): class OpusNotLoaded(DiscordException):
"""An exception that is thrown for when libopus is not loaded.""" """An exception that is thrown for when libopus is not loaded."""
pass pass
class _OpusStruct: class _OpusStruct:
SAMPLING_RATE = 48000 SAMPLING_RATE = 48000
CHANNELS = 2 CHANNELS = 2
FRAME_LENGTH = 20 # in milliseconds FRAME_LENGTH = 20 # in milliseconds
SAMPLE_SIZE = struct.calcsize('h') * CHANNELS SAMPLE_SIZE = struct.calcsize("h") * CHANNELS
SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH) SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH)
FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE
@ -304,7 +317,8 @@ class _OpusStruct:
if not is_loaded() and not _load_default(): if not is_loaded() and not _load_default():
raise OpusNotLoaded() raise OpusNotLoaded()
return _lib.opus_get_version_string().decode('utf-8') return _lib.opus_get_version_string().decode("utf-8")
class Encoder(_OpusStruct): class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO): def __init__(self, application: int = APPLICATION_AUDIO):
@ -315,14 +329,14 @@ class Encoder(_OpusStruct):
self.set_bitrate(128) self.set_bitrate(128)
self.set_fec(True) self.set_fec(True)
self.set_expected_packet_loss_percent(0.15) self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full') self.set_bandwidth("full")
self.set_signal_type('auto') self.set_signal_type("auto")
def __del__(self) -> None: def __del__(self) -> None:
if hasattr(self, '_state'): if hasattr(self, "_state"):
_lib.opus_encoder_destroy(self._state) _lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore self._state = None # type: ignore
def _create_state(self) -> EncoderStruct: def _create_state(self) -> EncoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
@ -352,18 +366,19 @@ class Encoder(_OpusStruct):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None: def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def encode(self, pcm: bytes, frame_size: int) -> bytes: def encode(self, pcm: bytes, frame_size: int) -> bytes:
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
# bytes can be used to reference pointer # bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)() data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know # array can be initialized with bytes but mypy doesn't know
return array.array('b', data[:ret]).tobytes() # type: ignore return array.array("b", data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):
def __init__(self): def __init__(self):
@ -372,10 +387,10 @@ class Decoder(_OpusStruct):
self._state: DecoderStruct = self._create_state() self._state: DecoderStruct = self._create_state()
def __del__(self) -> None: def __del__(self) -> None:
if hasattr(self, '_state'): if hasattr(self, "_state"):
_lib.opus_decoder_destroy(self._state) _lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None # This is a destructor, so it's okay to assign None
self._state = None # type: ignore self._state = None # type: ignore
def _create_state(self) -> DecoderStruct: def _create_state(self) -> DecoderStruct:
ret = ctypes.c_int() ret = ctypes.c_int()
@ -411,12 +426,12 @@ class Decoder(_OpusStruct):
def set_gain(self, dB: float) -> int: def set_gain(self, dB: float) -> int:
"""Sets the decoder gain in dB, from -128 to 128.""" """Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8) return self._set_gain(dB_Q8)
def set_volume(self, mult: float) -> int: def set_volume(self, mult: float) -> int:
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self) -> int: def _get_last_packet_duration(self) -> int:
"""Gets the duration (in samples) of the last packet successfully decoded or concealed.""" """Gets the duration (in samples) of the last packet successfully decoded or concealed."""
@ -451,4 +466,4 @@ class Decoder(_OpusStruct):
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
return array.array('h', pcm[:ret * channel_count]).tobytes() return array.array("h", pcm[: ret * channel_count]).tobytes()

View File

@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin
from .errors import InvalidArgument from .errors import InvalidArgument
from . import utils from . import utils
__all__ = ( __all__ = ("PartialEmoji",)
'PartialEmoji',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .state import ConnectionState from .state import ConnectionState
from datetime import datetime from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag: class _EmojiTag:
__slots__ = () __slots__ = ()
@ -49,7 +48,7 @@ class _EmojiTag:
raise NotImplementedError raise NotImplementedError
PE = TypeVar('PE', bound='PartialEmoji') PE = TypeVar("PE", bound="PartialEmoji")
class PartialEmoji(_EmojiTag, AssetMixin): class PartialEmoji(_EmojiTag, AssetMixin):
@ -90,9 +89,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
The ID of the custom emoji, if applicable. The ID of the custom emoji, if applicable.
""" """
__slots__ = ('animated', 'name', 'id', '_state') __slots__ = ("animated", "name", "id", "_state")
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?') _CUSTOM_EMOJI_RE = re.compile(r"<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?")
if TYPE_CHECKING: if TYPE_CHECKING:
id: Optional[int] id: Optional[int]
@ -106,9 +105,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
@classmethod @classmethod
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE: def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE:
return cls( return cls(
animated=data.get('animated', False), animated=data.get("animated", False),
id=utils._get_as_snowflake(data, 'id'), id=utils._get_as_snowflake(data, "id"),
name=data.get('name') or '', name=data.get("name") or "",
) )
@classmethod @classmethod
@ -139,19 +138,19 @@ class PartialEmoji(_EmojiTag, AssetMixin):
match = cls._CUSTOM_EMOJI_RE.match(value) match = cls._CUSTOM_EMOJI_RE.match(value)
if match is not None: if match is not None:
groups = match.groupdict() groups = match.groupdict()
animated = bool(groups['animated']) animated = bool(groups["animated"])
emoji_id = int(groups['id']) emoji_id = int(groups["id"])
name = groups['name'] name = groups["name"]
return cls(name=name, animated=animated, id=emoji_id) return cls(name=name, animated=animated, id=emoji_id)
return cls(name=value, id=None, animated=False) return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name} o: Dict[str, Any] = {"name": self.name}
if self.id: if self.id:
o['id'] = self.id o["id"] = self.id
if self.animated: if self.animated:
o['animated'] = self.animated o["animated"] = self.animated
return o return o
def _to_partial(self) -> PartialEmoji: def _to_partial(self) -> PartialEmoji:
@ -169,11 +168,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
if self.id is None: if self.id is None:
return self.name return self.name
if self.animated: if self.animated:
return f'<a:{self.name}:{self.id}>' return f"<a:{self.name}:{self.id}>"
return f'<:{self.name}:{self.id}>' return f"<:{self.name}:{self.id}>"
def __repr__(self): def __repr__(self):
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>"
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if self.is_unicode_emoji(): if self.is_unicode_emoji():
@ -200,7 +199,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
def _as_reaction(self) -> str: def _as_reaction(self) -> str:
if self.id is None: if self.id is None:
return self.name return self.name
return f'{self.name}:{self.id}' return f"{self.name}:{self.id}"
@property @property
def created_at(self) -> Optional[datetime]: def created_at(self) -> Optional[datetime]:
@ -220,13 +219,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
If this isn't a custom emoji then an empty string is returned If this isn't a custom emoji then an empty string is returned
""" """
if self.is_unicode_emoji(): if self.is_unicode_emoji():
return '' return ""
fmt = 'gif' if self.animated else 'png' fmt = "gif" if self.animated else "png"
return f'{Asset.BASE}/emojis/{self.id}.{fmt}' return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
async def read(self) -> bytes: async def read(self) -> bytes:
if self.is_unicode_emoji(): if self.is_unicode_emoji():
raise InvalidArgument('PartialEmoji is not a custom emoji') raise InvalidArgument("PartialEmoji is not a custom emoji")
return await super().read() return await super().read()

View File

@ -28,8 +28,8 @@ from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING,
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
__all__ = ( __all__ = (
'Permissions', "Permissions",
'PermissionOverwrite', "PermissionOverwrite",
) )
# A permission alias works like a regular flag but is marked # A permission alias works like a regular flag but is marked
@ -46,7 +46,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis
return decorator return decorator
P = TypeVar('P', bound='Permissions')
P = TypeVar("P", bound="Permissions")
@fill_with_flags() @fill_with_flags()
class Permissions(BaseFlags): class Permissions(BaseFlags):
@ -101,12 +103,12 @@ class Permissions(BaseFlags):
def __init__(self, permissions: int = 0, **kwargs: bool): def __init__(self, permissions: int = 0, **kwargs: bool):
if not isinstance(permissions, int): if not isinstance(permissions, int):
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.') raise TypeError(f"Expected int parameter, received {permissions.__class__.__name__} instead.")
self.value = permissions self.value = permissions
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_FLAGS: if key not in self.VALID_FLAGS:
raise TypeError(f'{key!r} is not a valid permission name.') raise TypeError(f"{key!r} is not a valid permission name.")
setattr(self, key, value) setattr(self, key, value)
def is_subset(self, other: Permissions) -> bool: def is_subset(self, other: Permissions) -> bool:
@ -299,6 +301,13 @@ class Permissions(BaseFlags):
""" """
return 1 << 3 return 1 << 3
@make_permission_alias("administrator")
def admin(self) -> int:
""":class:`bool`: An alias for :attr:`administrator`.
.. versionadded:: 2.0
"""
return 1 << 3
@flag_value @flag_value
def manage_channels(self) -> int: def manage_channels(self) -> int:
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild. """:class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.
@ -336,7 +345,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels.""" """:class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
return 1 << 10 return 1 << 10
@make_permission_alias('read_messages') @make_permission_alias("read_messages")
def view_channel(self) -> int: def view_channel(self) -> int:
""":class:`bool`: An alias for :attr:`read_messages`. """:class:`bool`: An alias for :attr:`read_messages`.
@ -389,7 +398,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds.""" """:class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
return 1 << 18 return 1 << 18
@make_permission_alias('external_emojis') @make_permission_alias("external_emojis")
def use_external_emojis(self) -> int: def use_external_emojis(self) -> int:
""":class:`bool`: An alias for :attr:`external_emojis`. """:class:`bool`: An alias for :attr:`external_emojis`.
@ -453,7 +462,7 @@ class Permissions(BaseFlags):
""" """
return 1 << 28 return 1 << 28
@make_permission_alias('manage_roles') @make_permission_alias("manage_roles")
def manage_permissions(self) -> int: def manage_permissions(self) -> int:
""":class:`bool`: An alias for :attr:`manage_roles`. """:class:`bool`: An alias for :attr:`manage_roles`.
@ -471,7 +480,7 @@ class Permissions(BaseFlags):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis.""" """:class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30 return 1 << 30
@make_permission_alias('manage_emojis') @make_permission_alias("manage_emojis")
def manage_emojis_and_stickers(self) -> int: def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`. """:class:`bool`: An alias for :attr:`manage_emojis`.
@ -535,7 +544,7 @@ class Permissions(BaseFlags):
""" """
return 1 << 37 return 1 << 37
@make_permission_alias('external_stickers') @make_permission_alias("external_stickers")
def use_external_stickers(self) -> int: def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`. """:class:`bool`: An alias for :attr:`external_stickers`.
@ -551,7 +560,9 @@ class Permissions(BaseFlags):
""" """
return 1 << 38 return 1 << 38
PO = TypeVar('PO', bound='PermissionOverwrite')
PO = TypeVar("PO", bound="PermissionOverwrite")
def _augment_from_permissions(cls): def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS) cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
@ -614,7 +625,7 @@ class PermissionOverwrite:
Set the value of permissions by their name. Set the value of permissions by their name.
""" """
__slots__ = ('_values',) __slots__ = ("_values",)
if TYPE_CHECKING: if TYPE_CHECKING:
VALID_NAMES: ClassVar[Set[str]] VALID_NAMES: ClassVar[Set[str]]
@ -670,7 +681,7 @@ class PermissionOverwrite:
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_NAMES: if key not in self.VALID_NAMES:
raise ValueError(f'no permission called {key}.') raise ValueError(f"no permission called {key}.")
setattr(self, key, value) setattr(self, key, value)
@ -679,7 +690,7 @@ class PermissionOverwrite:
def _set(self, key: str, value: Optional[bool]) -> None: def _set(self, key: str, value: Optional[bool]) -> None:
if value not in (True, None, False): if value not in (True, None, False):
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}') raise TypeError(f"Expected bool or NoneType, received {value.__class__.__name__}")
if value is None: if value is None:
self._values.pop(key, None) self._values.pop(key, None)

View File

@ -36,7 +36,7 @@ import sys
import re import re
import io import io
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .errors import ClientException from .errors import ClientException
from .opus import Encoder as OpusEncoder from .opus import Encoder as OpusEncoder
@ -47,27 +47,28 @@ if TYPE_CHECKING:
from .voice_client import VoiceClient from .voice_client import VoiceClient
AT = TypeVar('AT', bound='AudioSource') AT = TypeVar("AT", bound="AudioSource")
FT = TypeVar('FT', bound='FFmpegOpusAudio') FT = TypeVar("FT", bound="FFmpegOpusAudio")
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'AudioSource', "AudioSource",
'PCMAudio', "PCMAudio",
'FFmpegAudio', "FFmpegAudio",
'FFmpegPCMAudio', "FFmpegPCMAudio",
'FFmpegOpusAudio', "FFmpegOpusAudio",
'PCMVolumeTransformer', "PCMVolumeTransformer",
) )
CREATE_NO_WINDOW: int CREATE_NO_WINDOW: int
if sys.platform != 'win32': if sys.platform != "win32":
CREATE_NO_WINDOW = 0 CREATE_NO_WINDOW = 0
else: else:
CREATE_NO_WINDOW = 0x08000000 CREATE_NO_WINDOW = 0x08000000
class AudioSource: class AudioSource:
"""Represents an audio stream. """Represents an audio stream.
@ -114,6 +115,7 @@ class AudioSource:
def __del__(self) -> None: def __del__(self) -> None:
self.cleanup() self.cleanup()
class PCMAudio(AudioSource): class PCMAudio(AudioSource):
"""Represents raw 16-bit 48KHz stereo PCM audio source. """Represents raw 16-bit 48KHz stereo PCM audio source.
@ -122,15 +124,17 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object` stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM. A file-like object that reads byte data representing raw PCM.
""" """
def __init__(self, stream: io.BufferedIOBase) -> None: def __init__(self, stream: io.BufferedIOBase) -> None:
self.stream: io.BufferedIOBase = stream self.stream: io.BufferedIOBase = stream
def read(self) -> bytes: def read(self) -> bytes:
ret = self.stream.read(OpusEncoder.FRAME_SIZE) ret = self.stream.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE: if len(ret) != OpusEncoder.FRAME_SIZE:
return b'' return b""
return ret return ret
class FFmpegAudio(AudioSource): class FFmpegAudio(AudioSource):
"""Represents an FFmpeg (or AVConv) based AudioSource. """Represents an FFmpeg (or AVConv) based AudioSource.
@ -140,13 +144,15 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): def __init__(
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE self, source: Union[str, io.BufferedIOBase], *, executable: str = "ffmpeg", args: Any, **subprocess_kwargs: Any
):
piping = subprocess_kwargs.get("stdin") == subprocess.PIPE
if piping and isinstance(source, str): if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
args = [executable, *args] args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE} kwargs = {"stdout": subprocess.PIPE}
kwargs.update(subprocess_kwargs) kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs) self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
@ -155,7 +161,7 @@ class FFmpegAudio(AudioSource):
self._pipe_thread: Optional[threading.Thread] = None self._pipe_thread: Optional[threading.Thread] = None
if piping: if piping:
n = f'popen-stdin-writer:{id(self):#x}' n = f"popen-stdin-writer:{id(self):#x}"
self._stdin = self._process.stdin self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start() self._pipe_thread.start()
@ -165,10 +171,10 @@ class FFmpegAudio(AudioSource):
try: try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
except FileNotFoundError: except FileNotFoundError:
executable = args.partition(' ')[0] if isinstance(args, str) else args[0] executable = args.partition(" ")[0] if isinstance(args, str) else args[0]
raise ClientException(executable + ' was not found.') from None raise ClientException(executable + " was not found.") from None
except subprocess.SubprocessError as exc: except subprocess.SubprocessError as exc:
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc
else: else:
return process return process
@ -177,20 +183,19 @@ class FFmpegAudio(AudioSource):
if proc is MISSING: if proc is MISSING:
return return
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid) _log.info("Preparing to terminate ffmpeg process %s.", proc.pid)
try: try:
proc.kill() proc.kill()
except Exception: except Exception:
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid) _log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
if proc.poll() is None: if proc.poll() is None:
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid) _log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid)
proc.communicate() proc.communicate()
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode) _log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode)
else: else:
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) _log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode)
def _pipe_writer(self, source: io.BufferedIOBase) -> None: def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process: while self._process:
@ -202,7 +207,7 @@ class FFmpegAudio(AudioSource):
try: try:
self._stdin.write(data) self._stdin.write(data)
except Exception: except Exception:
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True) _log.debug("Write error for %s, this is probably not a problem", self, exc_info=True)
# at this point the source data is either exhausted or the process is fubar # at this point the source data is either exhausted or the process is fubar
self._process.terminate() self._process.terminate()
return return
@ -211,6 +216,7 @@ class FFmpegAudio(AudioSource):
self._kill_process() self._kill_process()
self._process = self._stdout = self._stdin = MISSING self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio): class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv). """An audio source from FFmpeg (or AVConv).
@ -250,38 +256,39 @@ class FFmpegPCMAudio(FFmpegAudio):
self, self,
source: Union[str, io.BufferedIOBase], source: Union[str, io.BufferedIOBase],
*, *,
executable: str = 'ffmpeg', executable: str = "ffmpeg",
pipe: bool = False, pipe: bool = False,
stderr: Optional[IO[str]] = None, stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None, before_options: Optional[str] = None,
options: Optional[str] = None options: Optional[str] = None,
) -> None: ) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
if isinstance(before_options, str): if isinstance(before_options, str):
args.extend(shlex.split(before_options)) args.extend(shlex.split(before_options))
args.append('-i') args.append("-i")
args.append('-' if pipe else source) args.append("-" if pipe else source)
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning')) args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"))
if isinstance(options, str): if isinstance(options, str):
args.extend(shlex.split(options)) args.extend(shlex.split(options))
args.append('pipe:1') args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs) super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
def read(self) -> bytes: def read(self) -> bytes:
ret = self._stdout.read(OpusEncoder.FRAME_SIZE) ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE: if len(ret) != OpusEncoder.FRAME_SIZE:
return b'' return b""
return ret return ret
def is_opus(self) -> bool: def is_opus(self) -> bool:
return False return False
class FFmpegOpusAudio(FFmpegAudio): class FFmpegOpusAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv). """An audio source from FFmpeg (or AVConv).
@ -349,7 +356,7 @@ class FFmpegOpusAudio(FFmpegAudio):
*, *,
bitrate: int = 128, bitrate: int = 128,
codec: Optional[str] = None, codec: Optional[str] = None,
executable: str = 'ffmpeg', executable: str = "ffmpeg",
pipe=False, pipe=False,
stderr=None, stderr=None,
before_options=None, before_options=None,
@ -357,28 +364,39 @@ class FFmpegOpusAudio(FFmpegAudio):
) -> None: ) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
if isinstance(before_options, str): if isinstance(before_options, str):
args.extend(shlex.split(before_options)) args.extend(shlex.split(before_options))
args.append('-i') args.append("-i")
args.append('-' if pipe else source) args.append("-" if pipe else source)
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus' codec = "copy" if codec in ("opus", "libopus") else "libopus"
args.extend(('-map_metadata', '-1', args.extend(
'-f', 'opus', (
'-c:a', codec, "-map_metadata",
'-ar', '48000', "-1",
'-ac', '2', "-f",
'-b:a', f'{bitrate}k', "opus",
'-loglevel', 'warning')) "-c:a",
codec,
"-ar",
"48000",
"-ac",
"2",
"-b:a",
f"{bitrate}k",
"-loglevel",
"warning",
)
)
if isinstance(options, str): if isinstance(options, str):
args.extend(shlex.split(options)) args.extend(shlex.split(options))
args.append('pipe:1') args.append("pipe:1")
super().__init__(source, executable=executable, args=args, **subprocess_kwargs) super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
self._packet_iter = OggStream(self._stdout).iter_packets() self._packet_iter = OggStream(self._stdout).iter_packets()
@ -446,7 +464,7 @@ class FFmpegOpusAudio(FFmpegAudio):
An instance of this class. An instance of this class.
""" """
executable = kwargs.get('executable') executable = kwargs.get("executable")
codec, bitrate = await cls.probe(source, method=method, executable=executable) codec, bitrate = await cls.probe(source, method=method, executable=executable)
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore
@ -484,12 +502,12 @@ class FFmpegOpusAudio(FFmpegAudio):
A 2-tuple with the codec and bitrate of the input source. A 2-tuple with the codec and bitrate of the input source.
""" """
method = method or 'native' method = method or "native"
executable = executable or 'ffmpeg' executable = executable or "ffmpeg"
probefunc = fallback = None probefunc = fallback = None
if isinstance(method, str): if isinstance(method, str):
probefunc = getattr(cls, '_probe_codec_' + method, None) probefunc = getattr(cls, "_probe_codec_" + method, None)
if probefunc is None: if probefunc is None:
raise AttributeError(f"Invalid probe method {method!r}") raise AttributeError(f"Invalid probe method {method!r}")
@ -500,8 +518,7 @@ class FFmpegOpusAudio(FFmpegAudio):
probefunc = method probefunc = method
fallback = cls._probe_codec_fallback fallback = cls._probe_codec_fallback
else: else:
raise TypeError("Expected str or callable for parameter 'probe', " \ raise TypeError("Expected str or callable for parameter 'probe', " f"not '{method.__class__.__name__}'")
f"not '{method.__class__.__name__}'")
codec = bitrate = None codec = bitrate = None
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -525,28 +542,28 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate return codec, bitrate
@staticmethod @staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_native(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable exe = executable[:2] + "probe" if executable in ("ffmpeg", "avconv") else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] args = [exe, "-v", "quiet", "-print_format", "json", "-show_streams", "-select_streams", "a:0", source]
output = subprocess.check_output(args, timeout=20) output = subprocess.check_output(args, timeout=20)
codec = bitrate = None codec = bitrate = None
if output: if output:
data = json.loads(output) data = json.loads(output)
streamdata = data['streams'][0] streamdata = data["streams"][0]
codec = streamdata.get('codec_name') codec = streamdata.get("codec_name")
bitrate = int(streamdata.get('bit_rate', 0)) bitrate = int(streamdata.get("bit_rate", 0))
bitrate = max(round(bitrate/1000), 512) bitrate = max(round(bitrate / 1000), 512)
return codec, bitrate return codec, bitrate
@staticmethod @staticmethod
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_fallback(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
args = [executable, '-hide_banner', '-i', source] args = [executable, "-hide_banner", "-i", source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate(timeout=20) out, _ = proc.communicate(timeout=20)
output = out.decode('utf8') output = out.decode("utf8")
codec = bitrate = None codec = bitrate = None
codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output) codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output)
@ -560,11 +577,12 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate return codec, bitrate
def read(self) -> bytes: def read(self) -> bytes:
return next(self._packet_iter, b'') return next(self._packet_iter, b"")
def is_opus(self) -> bool: def is_opus(self) -> bool:
return True return True
class PCMVolumeTransformer(AudioSource, Generic[AT]): class PCMVolumeTransformer(AudioSource, Generic[AT]):
"""Transforms a previous :class:`AudioSource` to have volume controls. """Transforms a previous :class:`AudioSource` to have volume controls.
@ -589,10 +607,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
def __init__(self, original: AT, volume: float = 1.0): def __init__(self, original: AT, volume: float = 1.0):
if not isinstance(original, AudioSource): if not isinstance(original, AudioSource):
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') raise TypeError(f"expected AudioSource not {original.__class__.__name__}.")
if original.is_opus(): if original.is_opus():
raise ClientException('AudioSource must not be Opus encoded.') raise ClientException("AudioSource must not be Opus encoded.")
self.original: AT = original self.original: AT = original
self.volume = volume self.volume = volume
@ -613,6 +631,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
ret = self.original.read() ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0)) return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread): class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
@ -625,7 +644,7 @@ class AudioPlayer(threading.Thread):
self._end: threading.Event = threading.Event() self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock() self._lock: threading.Lock = threading.Lock()
@ -685,11 +704,11 @@ class AudioPlayer(threading.Thread):
try: try:
self.after(error) self.after(error)
except Exception as exc: except Exception as exc:
_log.exception('Calling the after function failed.') _log.exception("Calling the after function failed.")
exc.__context__ = error exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__) traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error: elif error:
msg = f'Exception in voice thread {self.name}' msg = f"Exception in voice thread {self.name}"
_log.exception(msg, exc_info=error) _log.exception(msg, exc_info=error)
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__) traceback.print_exception(type(error), error, error.__traceback__)

View File

@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING: if TYPE_CHECKING:
@ -34,7 +35,8 @@ if TYPE_CHECKING:
MessageUpdateEvent, MessageUpdateEvent,
ReactionClearEvent, ReactionClearEvent,
ReactionClearEmojiEvent, ReactionClearEmojiEvent,
IntegrationDeleteEvent IntegrationDeleteEvent,
TypingEvent,
) )
from .message import Message from .message import Message
from .partial_emoji import PartialEmoji from .partial_emoji import PartialEmoji
@ -42,20 +44,21 @@ if TYPE_CHECKING:
__all__ = ( __all__ = (
'RawMessageDeleteEvent', "RawMessageDeleteEvent",
'RawBulkMessageDeleteEvent', "RawBulkMessageDeleteEvent",
'RawMessageUpdateEvent', "RawMessageUpdateEvent",
'RawReactionActionEvent', "RawReactionActionEvent",
'RawReactionClearEvent', "RawReactionClearEvent",
'RawReactionClearEmojiEvent', "RawReactionClearEmojiEvent",
'RawIntegrationDeleteEvent', "RawIntegrationDeleteEvent",
"RawTypingEvent",
) )
class _RawReprMixin: class _RawReprMixin:
def __repr__(self) -> str: def __repr__(self) -> str:
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__) value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>' return f"<{self.__class__.__name__} {value}>"
class RawMessageDeleteEvent(_RawReprMixin): class RawMessageDeleteEvent(_RawReprMixin):
@ -73,14 +76,14 @@ class RawMessageDeleteEvent(_RawReprMixin):
The cached message, if found in the internal message cache. The cached message, if found in the internal message cache.
""" """
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message') __slots__ = ("message_id", "channel_id", "guild_id", "cached_message")
def __init__(self, data: MessageDeleteEvent) -> None: def __init__(self, data: MessageDeleteEvent) -> None:
self.message_id: int = int(data['id']) self.message_id: int = int(data["id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
self.cached_message: Optional[Message] = None self.cached_message: Optional[Message] = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -100,15 +103,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
The cached messages, if found in the internal message cache. The cached messages, if found in the internal message cache.
""" """
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages') __slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages")
def __init__(self, data: BulkMessageDeleteEvent) -> None: def __init__(self, data: BulkMessageDeleteEvent) -> None:
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])} self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])}
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
self.cached_messages: List[Message] = [] self.cached_messages: List[Message] = []
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -136,16 +139,16 @@ class RawMessageUpdateEvent(_RawReprMixin):
it is modified by the data in :attr:`RawMessageUpdateEvent.data`. it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
""" """
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message') __slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message")
def __init__(self, data: MessageUpdateEvent) -> None: def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data['id']) self.message_id: int = int(data["id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
self.data: MessageUpdateEvent = data self.data: MessageUpdateEvent = data
self.cached_message: Optional[Message] = None self.cached_message: Optional[Message] = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -179,19 +182,18 @@ class RawReactionActionEvent(_RawReprMixin):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', __slots__ = ("message_id", "user_id", "channel_id", "guild_id", "emoji", "event_type", "member")
'event_type', 'member')
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data['message_id']) self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data['user_id']) self.user_id: int = int(data["user_id"])
self.emoji: PartialEmoji = emoji self.emoji: PartialEmoji = emoji
self.event_type: str = event_type self.event_type: str = event_type
self.member: Optional[Member] = None self.member: Optional[Member] = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -209,14 +211,14 @@ class RawReactionClearEvent(_RawReprMixin):
The guild ID where the reactions got cleared. The guild ID where the reactions got cleared.
""" """
__slots__ = ('message_id', 'channel_id', 'guild_id') __slots__ = ("message_id", "channel_id", "guild_id")
def __init__(self, data: ReactionClearEvent) -> None: def __init__(self, data: ReactionClearEvent) -> None:
self.message_id: int = int(data['message_id']) self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -238,15 +240,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
The custom or unicode emoji being removed. The custom or unicode emoji being removed.
""" """
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji') __slots__ = ("message_id", "channel_id", "guild_id", "emoji")
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
self.emoji: PartialEmoji = emoji self.emoji: PartialEmoji = emoji
self.message_id: int = int(data['message_id']) self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id: Optional[int] = None
@ -266,13 +268,46 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
The guild ID where the integration got deleted. The guild ID where the integration got deleted.
""" """
__slots__ = ('integration_id', 'application_id', 'guild_id') __slots__ = ("integration_id", "application_id", "guild_id")
def __init__(self, data: IntegrationDeleteEvent) -> None: def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data['id']) self.integration_id: int = int(data["id"])
self.guild_id: int = int(data['guild_id']) self.guild_id: int = int(data["guild_id"])
try: try:
self.application_id: Optional[int] = int(data['application_id']) self.application_id: Optional[int] = int(data["application_id"])
except KeyError: except KeyError:
self.application_id: Optional[int] = None self.application_id: Optional[int] = None
class RawTypingEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_typing` event.
.. versionadded:: 2.0
Attributes
-----------
channel_id: :class:`int`
The channel ID where the typing originated from.
user_id: :class:`int`
The ID of the user that started typing.
when: :class:`datetime.datetime`
When the typing started as an aware datetime in UTC.
guild_id: Optional[:class:`int`]
The guild ID where the typing originated from, if applicable.
member: Optional[:class:`Member`]
The member who started typing. Only available if the member started typing in a guild.
"""
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
def __init__(self, data: TypingEvent) -> None:
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get("timestamp"), tz=datetime.timezone.utc)
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None

View File

@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional
from .iterators import ReactionIterator from .iterators import ReactionIterator
__all__ = ( __all__ = ("Reaction",)
'Reaction',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.message import Reaction as ReactionPayload from .types.message import Reaction as ReactionPayload
@ -38,6 +36,7 @@ if TYPE_CHECKING:
from .emoji import Emoji from .emoji import Emoji
from .abc import Snowflake from .abc import Snowflake
class Reaction: class Reaction:
"""Represents a reaction to a message. """Represents a reaction to a message.
@ -75,13 +74,16 @@ class Reaction:
message: :class:`Message` message: :class:`Message`
Message this reaction is for. Message this reaction is for.
""" """
__slots__ = ('message', 'count', 'emoji', 'me')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): __slots__ = ("message", "count", "emoji", "me")
def __init__(
self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None
):
self.message: Message = message self.message: Message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data["emoji"])
self.count: int = data.get('count', 1) self.count: int = data.get("count", 1)
self.me: bool = data.get('me') self.me: bool = data.get("me")
# TODO: typeguard # TODO: typeguard
def is_custom_emoji(self) -> bool: def is_custom_emoji(self) -> bool:
@ -103,7 +105,7 @@ class Reaction:
return str(self.emoji) return str(self.emoji)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>' return f"<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>"
async def remove(self, user: Snowflake) -> None: async def remove(self, user: Snowflake) -> None:
"""|coro| """|coro|
@ -201,7 +203,7 @@ class Reaction:
""" """
if not isinstance(self.emoji, str): if not isinstance(self.emoji, str):
emoji = f'{self.emoji.name}:{self.emoji.id}' emoji = f"{self.emoji.name}:{self.emoji.id}"
else: else:
emoji = self.emoji emoji = self.emoji

View File

@ -32,8 +32,8 @@ from .mixins import Hashable
from .utils import snowflake_time, _get_as_snowflake, MISSING from .utils import snowflake_time, _get_as_snowflake, MISSING
__all__ = ( __all__ = (
'RoleTags', "RoleTags",
'Role', "Role",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -68,19 +68,19 @@ class RoleTags:
""" """
__slots__ = ( __slots__ = (
'bot_id', "bot_id",
'integration_id', "integration_id",
'_premium_subscriber', "_premium_subscriber",
) )
def __init__(self, data: RoleTagPayload): def __init__(self, data: RoleTagPayload):
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id') self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id")
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id') self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id")
# NOTE: The API returns "null" for this if it's valid, which corresponds to None. # NOTE: The API returns "null" for this if it's valid, which corresponds to None.
# This is different from other fields where "null" means "not there". # This is different from other fields where "null" means "not there".
# So in this case, a value of None is the same as True. # So in this case, a value of None is the same as True.
# Which means we would need a different sentinel. # Which means we would need a different sentinel.
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING) self._premium_subscriber: Optional[Any] = data.get("premium_subscriber", MISSING)
def is_bot_managed(self) -> bool: def is_bot_managed(self) -> bool:
""":class:`bool`: Whether the role is associated with a bot.""" """:class:`bool`: Whether the role is associated with a bot."""
@ -96,12 +96,12 @@ class RoleTags:
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} ' f"<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} "
f'premium_subscriber={self.is_premium_subscriber()}>' f"premium_subscriber={self.is_premium_subscriber()}>"
) )
R = TypeVar('R', bound='Role') R = TypeVar("R", bound="Role")
class Role(Hashable): class Role(Hashable):
@ -141,6 +141,14 @@ class Role(Hashable):
Returns the role's name. Returns the role's name.
.. describe:: str(x)
Returns the role's ID.
.. describe:: int(x)
Returns the role's ID.
Attributes Attributes
---------- ----------
id: :class:`int` id: :class:`int`
@ -173,37 +181,40 @@ class Role(Hashable):
""" """
__slots__ = ( __slots__ = (
'id', "id",
'name', "name",
'_permissions', "_permissions",
'_colour', "_colour",
'position', "position",
'managed', "managed",
'mentionable', "mentionable",
'hoist', "hoist",
'guild', "guild",
'tags', "tags",
'_state', "_state",
) )
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload):
self.guild: Guild = guild self.guild: Guild = guild
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self._update(data) self._update(data)
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def __int__(self) -> int:
return self.id
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Role id={self.id} name={self.name!r}>' return f"<Role id={self.id} name={self.name!r}>"
def __lt__(self: R, other: R) -> bool: def __lt__(self: R, other: R) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role): if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented return NotImplemented
if self.guild != other.guild: if self.guild != other.guild:
raise RuntimeError('cannot compare roles from two different guilds.') raise RuntimeError("cannot compare roles from two different guilds.")
# the @everyone role is always the lowest role in hierarchy # the @everyone role is always the lowest role in hierarchy
guild_id = self.guild.id guild_id = self.guild.id
@ -235,17 +246,17 @@ class Role(Hashable):
return not r return not r
def _update(self, data: RolePayload): def _update(self, data: RolePayload):
self.name: str = data['name'] self.name: str = data["name"]
self._permissions: int = int(data.get('permissions', 0)) self._permissions: int = int(data.get("permissions", 0))
self.position: int = data.get('position', 0) self.position: int = data.get("position", 0)
self._colour: int = data.get('color', 0) self._colour: int = data.get("color", 0)
self.hoist: bool = data.get('hoist', False) self.hoist: bool = data.get("hoist", False)
self.managed: bool = data.get('managed', False) self.managed: bool = data.get("managed", False)
self.mentionable: bool = data.get('mentionable', False) self.mentionable: bool = data.get("mentionable", False)
self.tags: Optional[RoleTags] self.tags: Optional[RoleTags]
try: try:
self.tags = RoleTags(data['tags']) self.tags = RoleTags(data["tags"])
except KeyError: except KeyError:
self.tags = None self.tags = None
@ -305,7 +316,7 @@ class Role(Hashable):
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention a role.""" """:class:`str`: Returns a string that allows you to mention a role."""
return f'<@&{self.id}>' return f"<@&{self.id}>"
@property @property
def members(self) -> List[Member]: def members(self) -> List[Member]:
@ -409,21 +420,21 @@ class Role(Hashable):
if colour is not MISSING: if colour is not MISSING:
if isinstance(colour, int): if isinstance(colour, int):
payload['color'] = colour payload["color"] = colour
else: else:
payload['color'] = colour.value payload["color"] = colour.value
if name is not MISSING: if name is not MISSING:
payload['name'] = name payload["name"] = name
if permissions is not MISSING: if permissions is not MISSING:
payload['permissions'] = permissions.value payload["permissions"] = permissions.value
if hoist is not MISSING: if hoist is not MISSING:
payload['hoist'] = hoist payload["hoist"] = hoist
if mentionable is not MISSING: if mentionable is not MISSING:
payload['mentionable'] = mentionable payload["mentionable"] = mentionable
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
return Role(guild=self.guild, data=data, state=self._state) return Role(guild=self.guild, data=data, state=self._state)

View File

@ -50,11 +50,11 @@ if TYPE_CHECKING:
from .activity import BaseActivity from .activity import BaseActivity
from .enums import Status from .enums import Status
EI = TypeVar('EI', bound='EventItem') EI = TypeVar("EI", bound="EventItem")
__all__ = ( __all__ = (
'AutoShardedClient', "AutoShardedClient",
'ShardInfo', "ShardInfo",
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -70,11 +70,11 @@ class EventType:
class EventItem: class EventItem:
__slots__ = ('type', 'shard', 'error') __slots__ = ("type", "shard", "error")
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: def __init__(self, etype: int, shard: Optional["Shard"], error: Optional[Exception]) -> None:
self.type: int = etype self.type: int = etype
self.shard: Optional['Shard'] = shard self.shard: Optional["Shard"] = shard
self.error: Optional[Exception] = error self.error: Optional[Exception] = error
def __lt__(self: EI, other: EI) -> bool: def __lt__(self: EI, other: EI) -> bool:
@ -129,11 +129,11 @@ class Shard:
async def disconnect(self) -> None: async def disconnect(self) -> None:
await self.close() await self.close()
self._dispatch('shard_disconnect', self.id) self._dispatch("shard_disconnect", self.id)
async def _handle_disconnect(self, e: Exception) -> None: async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch('disconnect') self._dispatch("disconnect")
self._dispatch('shard_disconnect', self.id) self._dispatch("shard_disconnect", self.id)
if not self._reconnect: if not self._reconnect:
self._queue_put(EventItem(EventType.close, self, e)) self._queue_put(EventItem(EventType.close, self, e))
return return
@ -156,7 +156,7 @@ class Shard:
return return
retry = self._backoff.delay() retry = self._backoff.delay()
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) _log.error("Attempting a reconnect for shard ID %s in %.2fs", self.id, retry, exc_info=e)
await asyncio.sleep(retry) await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e)) self._queue_put(EventItem(EventType.reconnect, self, e))
@ -179,9 +179,9 @@ class Shard:
async def reidentify(self, exc: ReconnectWebSocket) -> None: async def reidentify(self, exc: ReconnectWebSocket) -> None:
self._cancel_task() self._cancel_task()
self._dispatch('disconnect') self._dispatch("disconnect")
self._dispatch('shard_disconnect', self.id) self._dispatch("shard_disconnect", self.id)
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) _log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
try: try:
coro = DiscordWebSocket.from_client( coro = DiscordWebSocket.from_client(
self._client, self._client,
@ -231,7 +231,7 @@ class ShardInfo:
The shard count for this cluster. If this is ``None`` then the bot has not started yet. The shard count for this cluster. If this is ``None`` then the bot has not started yet.
""" """
__slots__ = ('_parent', 'id', 'shard_count') __slots__ = ("_parent", "id", "shard_count")
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent: Shard = parent self._parent: Shard = parent
@ -321,15 +321,15 @@ class AutoShardedClient(Client):
_connection: AutoShardedConnectionState _connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop("shard_id", None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None)
super().__init__(*args, loop=loop, **kwargs) super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None: if self.shard_ids is not None:
if self.shard_count is None: if self.shard_count is None:
raise ClientException('When passing manual shard_ids, you must provide a shard_count.') raise ClientException("When passing manual shard_ids, you must provide a shard_count.")
elif not isinstance(self.shard_ids, (list, tuple)): elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException('shard_ids parameter must be a list or a tuple.') raise ClientException("shard_ids parameter must be a list or a tuple.")
# instead of a single websocket, we have multiple # instead of a single websocket, we have multiple
# the key is the shard_id # the key is the shard_id
@ -363,7 +363,7 @@ class AutoShardedClient(Client):
:attr:`latencies` property. Returns ``nan`` if there are no shards ready. :attr:`latencies` property. Returns ``nan`` if there are no shards ready.
""" """
if not self.__shards: if not self.__shards:
return float('nan') return float("nan")
return sum(latency for _, latency in self.latencies) / len(self.__shards) return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property @property
@ -393,7 +393,7 @@ class AutoShardedClient(Client):
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0) ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception: except Exception:
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) _log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
await asyncio.sleep(5.0) await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id) return await self.launch_shard(gateway, shard_id)
@ -503,10 +503,10 @@ class AutoShardedClient(Client):
""" """
if status is None: if status is None:
status_value = 'online' status_value = "online"
status_enum = Status.online status_enum = Status.online
elif status is Status.offline: elif status is Status.offline:
status_value = 'invisible' status_value = "invisible"
status_enum = Status.offline status_enum = Status.offline
else: else:
status_enum = status status_enum = status

View File

@ -31,9 +31,7 @@ from .mixins import Hashable
from .errors import InvalidArgument from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum from .enums import StagePrivacyLevel, try_enum
__all__ = ( __all__ = ("StageInstance",)
'StageInstance',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload from .types.channel import StageInstance as StageInstancePayload
@ -61,6 +59,10 @@ class StageInstance(Hashable):
Returns the stage instance's hash. Returns the stage instance's hash.
.. describe:: int(x)
Returns the stage instance's ID.
Attributes Attributes
----------- -----------
id: :class:`int` id: :class:`int`
@ -78,14 +80,14 @@ class StageInstance(Hashable):
""" """
__slots__ = ( __slots__ = (
'_state', "_state",
'id', "id",
'guild', "guild",
'channel_id', "channel_id",
'topic', "topic",
'privacy_level', "privacy_level",
'discoverable_disabled', "discoverable_disabled",
'_cs_channel', "_cs_channel",
) )
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
@ -94,25 +96,27 @@ class StageInstance(Hashable):
self._update(data) self._update(data)
def _update(self, data: StageInstancePayload): def _update(self, data: StageInstancePayload):
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.channel_id: int = int(data['channel_id']) self.channel_id: int = int(data["channel_id"])
self.topic: str = data['topic'] self.topic: str = data["topic"]
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level']) self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data["privacy_level"])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False) self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>' return f"<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>"
@cached_slot_property('_cs_channel') @cached_slot_property("_cs_channel")
def channel(self) -> Optional[StageChannel]: def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in.""" """Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None # the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore return self._state.get_channel(self.channel_id) # type: ignore
def is_public(self) -> bool: def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None: async def edit(
self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
) -> None:
"""|coro| """|coro|
Edits the stage instance. Edits the stage instance.
@ -142,13 +146,13 @@ class StageInstance(Hashable):
payload = {} payload = {}
if topic is not MISSING: if topic is not MISSING:
payload['topic'] = topic payload["topic"] = topic
if privacy_level is not MISSING: if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel): if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel') raise InvalidArgument("privacy_level field must be of type PrivacyLevel")
payload['privacy_level'] = privacy_level.value payload["privacy_level"] = privacy_level.value
if payload: if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason) await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)

File diff suppressed because it is too large Load Diff

View File

@ -33,11 +33,11 @@ from .errors import InvalidData
from .enums import StickerType, StickerFormatType, try_enum from .enums import StickerType, StickerFormatType, try_enum
__all__ = ( __all__ = (
'StickerPack', "StickerPack",
'StickerItem', "StickerItem",
'Sticker', "Sticker",
'StandardSticker', "StandardSticker",
'GuildSticker', "GuildSticker",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -67,6 +67,14 @@ class StickerPack(Hashable):
Returns the name of the sticker pack. Returns the name of the sticker pack.
.. describe:: hash(x)
Returns the hash of the sticker pack.
.. describe:: int(x)
Returns the ID of the sticker pack.
.. describe:: x == y .. describe:: x == y
Checks if the sticker pack is equal to another sticker pack. Checks if the sticker pack is equal to another sticker pack.
@ -94,15 +102,15 @@ class StickerPack(Hashable):
""" """
__slots__ = ( __slots__ = (
'_state', "_state",
'id', "id",
'stickers', "stickers",
'name', "name",
'sku_id', "sku_id",
'cover_sticker_id', "cover_sticker_id",
'cover_sticker', "cover_sticker",
'description', "description",
'_banner', "_banner",
) )
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None: def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
@ -110,15 +118,17 @@ class StickerPack(Hashable):
self._from_data(data) self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None: def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data["id"])
stickers = data['stickers'] stickers = data["stickers"]
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers] self.stickers: List[StandardSticker] = [
self.name: str = data['name'] StandardSticker(state=self._state, data=sticker) for sticker in stickers
self.sku_id: int = int(data['sku_id']) ]
self.cover_sticker_id: int = int(data['cover_sticker_id']) self.name: str = data["name"]
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore self.sku_id: int = int(data["sku_id"])
self.description: str = data['description'] self.cover_sticker_id: int = int(data["cover_sticker_id"])
self._banner: int = int(data['banner_asset_id']) self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data["description"]
self._banner: int = int(data["banner_asset_id"])
@property @property
def banner(self) -> Asset: def banner(self) -> Asset:
@ -126,7 +136,7 @@ class StickerPack(Hashable):
return Asset._from_sticker_banner(self._state, self._banner) return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>' return f"<StickerPack id={self.id} name={self.name!r} description={self.description!r}>"
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -197,17 +207,17 @@ class StickerItem(_StickerTag):
The URL for the sticker's image. The URL for the sticker's image.
""" """
__slots__ = ('_state', 'name', 'id', 'format', 'url') __slots__ = ("_state", "name", "id", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerItemPayload): def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.name: str = data['name'] self.name: str = data["name"]
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>' return f"<StickerItem id={self.id} name={self.name!r} format={self.format}>"
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -228,7 +238,7 @@ class StickerItem(_StickerTag):
The retrieved sticker. The retrieved sticker.
""" """
data: StickerPayload = await self._state.http.get_sticker(self.id) data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._state, data=data) return cls(state=self._state, data=data)
@ -267,21 +277,21 @@ class Sticker(_StickerTag):
The URL for the sticker's image. The URL for the sticker's image.
""" """
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url') __slots__ = ("_state", "id", "name", "description", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None: def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
self._state: ConnectionState = state self._state: ConnectionState = state
self._from_data(data) self._from_data(data)
def _from_data(self, data: StickerPayload) -> None: def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self.description: str = data['description'] self.description: str = data["description"]
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Sticker id={self.id} name={self.name!r}>' return f"<Sticker id={self.id} name={self.name!r}>"
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -329,21 +339,21 @@ class StandardSticker(Sticker):
The sticker's sort order within its pack. The sticker's sort order within its pack.
""" """
__slots__ = ('sort_value', 'pack_id', 'type', 'tags') __slots__ = ("sort_value", "pack_id", "type", "tags")
def _from_data(self, data: StandardStickerPayload) -> None: def _from_data(self, data: StandardStickerPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.sort_value: int = data['sort_value'] self.sort_value: int = data["sort_value"]
self.pack_id: int = int(data['pack_id']) self.pack_id: int = int(data["pack_id"])
self.type: StickerType = StickerType.standard self.type: StickerType = StickerType.standard
try: try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')] self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")]
except KeyError: except KeyError:
self.tags = [] self.tags = []
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>' return f"<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>"
async def pack(self) -> StickerPack: async def pack(self) -> StickerPack:
"""|coro| """|coro|
@ -363,12 +373,12 @@ class StandardSticker(Sticker):
The retrieved sticker pack. The retrieved sticker pack.
""" """
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs() data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data['sticker_packs'] packs = data["sticker_packs"]
pack = find(lambda d: int(d['id']) == self.pack_id, packs) pack = find(lambda d: int(d["id"]) == self.pack_id, packs)
if pack: if pack:
return StickerPack(state=self._state, data=pack) return StickerPack(state=self._state, data=pack)
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}') raise InvalidData(f"Could not find corresponding sticker pack for {self!r}")
class GuildSticker(Sticker): class GuildSticker(Sticker):
@ -411,21 +421,21 @@ class GuildSticker(Sticker):
The name of a unicode emoji that represents this sticker. The name of a unicode emoji that represents this sticker.
""" """
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild') __slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild")
def _from_data(self, data: GuildStickerPayload) -> None: def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data) super()._from_data(data)
self.available: bool = data['available'] self.available: bool = data["available"]
self.guild_id: int = int(data['guild_id']) self.guild_id: int = int(data["guild_id"])
user = data.get('user') user = data.get("user")
self.user: Optional[User] = self._state.store_user(user) if user else None self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data['tags'] self.emoji: str = data["tags"]
self.type: StickerType = StickerType.guild self.type: StickerType = StickerType.guild
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>' return f"<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>"
@cached_slot_property('_cs_guild') @cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]: def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from. """Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild. Could be ``None`` if the bot is not in the guild.
@ -472,10 +482,10 @@ class GuildSticker(Sticker):
payload: EditGuildSticker = {} payload: EditGuildSticker = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = name payload["name"] = name
if description is not MISSING: if description is not MISSING:
payload['description'] = description payload["description"] = description
if emoji is not MISSING: if emoji is not MISSING:
try: try:
@ -483,9 +493,9 @@ class GuildSticker(Sticker):
except TypeError: except TypeError:
pass pass
else: else:
emoji = emoji.replace(' ', '_') emoji = emoji.replace(" ", "_")
payload['tags'] = emoji payload["tags"] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason) data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
return GuildSticker(state=self._state, data=data) return GuildSticker(state=self._state, data=data)
@ -513,7 +523,9 @@ class GuildSticker(Sticker):
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason) await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: def _sticker_factory(
sticker_type: Literal[1, 2]
) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type) value = try_enum(StickerType, sticker_type)
if value == StickerType.standard: if value == StickerType.standard:
return StandardSticker, value return StandardSticker, value

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
) )
__all__ = ( __all__ = (
'Team', "Team",
'TeamMember', "TeamMember",
) )
@ -62,26 +62,26 @@ class Team:
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members') __slots__ = ("_state", "id", "name", "_icon", "owner_id", "members")
def __init__(self, state: ConnectionState, data: TeamPayload): def __init__(self, state: ConnectionState, data: TeamPayload):
self._state: ConnectionState = state self._state: ConnectionState = state
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.name: str = data['name'] self.name: str = data["name"]
self._icon: Optional[str] = data['icon'] self._icon: Optional[str] = data["icon"]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id') self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id")
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']] self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]]
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name}>' return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
@property @property
def icon(self) -> Optional[Asset]: def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any.""" """Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any."""
if self._icon is None: if self._icon is None:
return None return None
return Asset._from_icon(self._state, self.id, self._icon, path='team') return Asset._from_icon(self._state, self.id, self._icon, path="team")
@property @property
def owner(self) -> Optional[TeamMember]: def owner(self) -> Optional[TeamMember]:
@ -130,16 +130,16 @@ class TeamMember(BaseUser):
The membership state of the member (e.g. invited or accepted) The membership state of the member (e.g. invited or accepted)
""" """
__slots__ = ('team', 'membership_state', 'permissions') __slots__ = ("team", "membership_state", "permissions")
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload): def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload):
self.team: Team = team self.team: Team = team
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state']) self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data["membership_state"])
self.permissions: List[str] = data['permissions'] self.permissions: List[str] = data["permissions"]
super().__init__(state=state, data=data['user']) super().__init__(state=state, data=data["user"])
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>' f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>"
) )

View File

@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion from .enums import VoiceRegion
from .guild import Guild from .guild import Guild
__all__ = ( __all__ = ("Template",)
'Template',
)
if TYPE_CHECKING: if TYPE_CHECKING:
import datetime import datetime
@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = () __slots__ = ()
def __getattr__(self, attr): def __getattr__(self, attr):
raise AttributeError('PartialTemplateState does not support http methods.') raise AttributeError("PartialTemplateState does not support http methods.")
class _PartialTemplateState: class _PartialTemplateState:
@ -84,7 +82,7 @@ class _PartialTemplateState:
return [] return []
def __getattr__(self, attr): def __getattr__(self, attr):
raise AttributeError(f'PartialTemplateState does not support {attr!r}.') raise AttributeError(f"PartialTemplateState does not support {attr!r}.")
class Template: class Template:
@ -118,16 +116,16 @@ class Template:
""" """
__slots__ = ( __slots__ = (
'code', "code",
'uses', "uses",
'name', "name",
'description', "description",
'creator', "creator",
'created_at', "created_at",
'updated_at', "updated_at",
'source_guild', "source_guild",
'is_dirty', "is_dirty",
'_state', "_state",
) )
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None: def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
@ -135,35 +133,35 @@ class Template:
self._store(data) self._store(data)
def _store(self, data: TemplatePayload) -> None: def _store(self, data: TemplatePayload) -> None:
self.code: str = data['code'] self.code: str = data["code"]
self.uses: int = data['usage_count'] self.uses: int = data["usage_count"]
self.name: str = data['name'] self.name: str = data["name"]
self.description: Optional[str] = data['description'] self.description: Optional[str] = data["description"]
creator_data = data.get('creator') creator_data = data.get("creator")
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data) self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at')) self.updated_at: Optional[datetime.datetime] = parse_time(data.get("updated_at"))
guild_id = int(data['source_guild_id']) guild_id = int(data["source_guild_id"])
guild: Optional[Guild] = self._state._get_guild(guild_id) guild: Optional[Guild] = self._state._get_guild(guild_id)
self.source_guild: Guild self.source_guild: Guild
if guild is None: if guild is None:
source_serialised = data['serialized_source_guild'] source_serialised = data["serialized_source_guild"]
source_serialised['id'] = guild_id source_serialised["id"] = guild_id
state = _PartialTemplateState(state=self._state) state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState # Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else: else:
self.source_guild = guild self.source_guild = guild
self.is_dirty: Optional[bool] = data.get('is_dirty', None) self.is_dirty: Optional[bool] = data.get("is_dirty", None)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}' f"<Template code={self.code!r} uses={self.uses} name={self.name!r}"
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>' f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>"
) )
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild: async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
@ -279,9 +277,9 @@ class Template:
payload = {} payload = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = name payload["name"] = name
if description is not MISSING: if description is not MISSING:
payload['description'] = description payload["description"] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload) data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
return Template(state=self._state, data=data) return Template(state=self._state, data=data)
@ -313,4 +311,4 @@ class Template:
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
return f'https://discord.new/{self.code}' return f"https://discord.new/{self.code}"

View File

@ -35,8 +35,8 @@ from .errors import ClientException
from .utils import MISSING, parse_time, _get_as_snowflake from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = ( __all__ = (
'Thread', "Thread",
'ThreadMember', "ThreadMember",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -74,6 +74,10 @@ class Thread(Messageable, Hashable):
Returns the thread's hash. Returns the thread's hash.
.. describe:: int(x)
Returns the thread's ID.
.. describe:: str(x) .. describe:: str(x)
Returns the thread's name. Returns the thread's name.
@ -124,25 +128,25 @@ class Thread(Messageable, Hashable):
""" """
__slots__ = ( __slots__ = (
'name', "name",
'id', "id",
'guild', "guild",
'_type', "_type",
'_state', "_state",
'_members', "_members",
'owner_id', "owner_id",
'parent_id', "parent_id",
'last_message_id', "last_message_id",
'message_count', "message_count",
'member_count', "member_count",
'slowmode_delay', "slowmode_delay",
'me', "me",
'locked', "locked",
'archived', "archived",
'invitable', "invitable",
'archiver_id', "archiver_id",
'auto_archive_duration', "auto_archive_duration",
'archive_timestamp', "archive_timestamp",
) )
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload): def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
@ -156,50 +160,50 @@ class Thread(Messageable, Hashable):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}' f"<Thread id={self.id!r} name={self.name!r} parent={self.parent}"
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>' f" owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>"
) )
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
def _from_data(self, data: ThreadPayload): def _from_data(self, data: ThreadPayload):
self.id = int(data['id']) self.id = int(data["id"])
self.parent_id = int(data['parent_id']) self.parent_id = int(data["parent_id"])
self.owner_id = int(data['owner_id']) self.owner_id = int(data["owner_id"])
self.name = data['name'] self.name = data["name"]
self._type = try_enum(ChannelType, data['type']) self._type = try_enum(ChannelType, data["type"])
self.last_message_id = _get_as_snowflake(data, 'last_message_id') self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get('rate_limit_per_user', 0) self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data['message_count'] self.message_count = data["message_count"]
self.member_count = data['member_count'] self.member_count = data["member_count"]
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data["thread_metadata"])
try: try:
member = data['member'] member = data["member"]
except KeyError: except KeyError:
self.me = None self.me = None
else: else:
self.me = ThreadMember(self, member) self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata): def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived'] self.archived = data["archived"]
self.archiver_id = _get_as_snowflake(data, 'archiver_id') self.archiver_id = _get_as_snowflake(data, "archiver_id")
self.auto_archive_duration = data['auto_archive_duration'] self.auto_archive_duration = data["auto_archive_duration"]
self.archive_timestamp = parse_time(data['archive_timestamp']) self.archive_timestamp = parse_time(data["archive_timestamp"])
self.locked = data.get('locked', False) self.locked = data.get("locked", False)
self.invitable = data.get('invitable', True) self.invitable = data.get("invitable", True)
def _update(self, data): def _update(self, data):
try: try:
self.name = data['name'] self.name = data["name"]
except KeyError: except KeyError:
pass pass
self.slowmode_delay = data.get('rate_limit_per_user', 0) self.slowmode_delay = data.get("rate_limit_per_user", 0)
try: try:
self._unroll_metadata(data['thread_metadata']) self._unroll_metadata(data["thread_metadata"])
except KeyError: except KeyError:
pass pass
@ -221,7 +225,7 @@ class Thread(Messageable, Hashable):
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread.""" """:class:`str`: The string that allows you to mention the thread."""
return f'<#{self.id}>' return f"<#{self.id}>"
@property @property
def members(self) -> List[ThreadMember]: def members(self) -> List[ThreadMember]:
@ -271,7 +275,7 @@ class Thread(Messageable, Hashable):
parent = self.parent parent = self.parent
if parent is None: if parent is None:
raise ClientException('Parent channel not found') raise ClientException("Parent channel not found")
return parent.category return parent.category
@property @property
@ -291,7 +295,7 @@ class Thread(Messageable, Hashable):
parent = self.parent parent = self.parent
if parent is None: if parent is None:
raise ClientException('Parent channel not found') raise ClientException("Parent channel not found")
return parent.category_id return parent.category_id
def is_private(self) -> bool: def is_private(self) -> bool:
@ -348,7 +352,7 @@ class Thread(Messageable, Hashable):
parent = self.parent parent = self.parent
if parent is None: if parent is None:
raise ClientException('Parent channel not found') raise ClientException("Parent channel not found")
return parent.permissions_for(obj) return parent.permissions_for(obj)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None: async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@ -398,7 +402,7 @@ class Thread(Messageable, Hashable):
return return
if len(messages) > 100: if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages') raise ClientException("Can only bulk delete messages up to 100 messages")
message_ids: SnowflakeList = [m.id for m in messages] message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids) await self._state.http.delete_messages(self.id, message_ids)
@ -573,17 +577,17 @@ class Thread(Messageable, Hashable):
""" """
payload = {} payload = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = str(name) payload["name"] = str(name)
if archived is not MISSING: if archived is not MISSING:
payload['archived'] = archived payload["archived"] = archived
if auto_archive_duration is not MISSING: if auto_archive_duration is not MISSING:
payload['auto_archive_duration'] = auto_archive_duration payload["auto_archive_duration"] = auto_archive_duration
if locked is not MISSING: if locked is not MISSING:
payload['locked'] = locked payload["locked"] = locked
if invitable is not MISSING: if invitable is not MISSING:
payload['invitable'] = invitable payload["invitable"] = invitable
if slowmode_delay is not MISSING: if slowmode_delay is not MISSING:
payload['rate_limit_per_user'] = slowmode_delay payload["rate_limit_per_user"] = slowmode_delay
data = await self._state.http.edit_channel(self.id, **payload) data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload # The data payload will always be a Thread payload
@ -748,6 +752,10 @@ class ThreadMember(Hashable):
Returns the thread member's hash. Returns the thread member's hash.
.. describe:: int(x)
Returns the thread member's ID.
.. describe:: str(x) .. describe:: str(x)
Returns the thread member's name. Returns the thread member's name.
@ -765,12 +773,12 @@ class ThreadMember(Hashable):
""" """
__slots__ = ( __slots__ = (
'id', "id",
'thread_id', "thread_id",
'joined_at', "joined_at",
'flags', "flags",
'_state', "_state",
'parent', "parent",
) )
def __init__(self, parent: Thread, data: ThreadMemberPayload): def __init__(self, parent: Thread, data: ThreadMemberPayload):
@ -779,24 +787,60 @@ class ThreadMember(Hashable):
self._from_data(data) self._from_data(data)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>' return f"<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>"
def _from_data(self, data: ThreadMemberPayload): def _from_data(self, data: ThreadMemberPayload):
try: try:
self.id = int(data['user_id']) self.id = int(data["user_id"])
except KeyError: except KeyError:
assert self._state.self_id is not None assert self._state.self_id is not None
self.id = self._state.self_id self.id = self._state.self_id
try: try:
self.thread_id = int(data['id']) self.thread_id = int(data["id"])
except KeyError: except KeyError:
self.thread_id = self.parent.id self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp']) self.joined_at = parse_time(data["join_timestamp"])
self.flags = data['flags'] self.flags = data["flags"]
@property @property
def thread(self) -> Thread: def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to.""" """:class:`Thread`: The thread this member belongs to."""
return self.parent return self.parent
async def fetch_member(self) -> Member:
"""|coro|
Retrieves a :class:`Member` from the ThreadMember object.
.. note::
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
Raises
-------
Forbidden
You do not have access to the guild.
HTTPException
Fetching the member failed.
Returns
--------
:class:`Member`
The member.
"""
return await self.thread.guild.fetch_member(self.id)
def get_member(self) -> Optional[Member]:
"""
Get the :class:`Member` from cache for the ThreadMember object.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
return self.thread.guild.get_member(self.id)

View File

@ -29,7 +29,7 @@ from .user import PartialUser
from .snowflake import Snowflake from .snowflake import Snowflake
StatusType = Literal['idle', 'dnd', 'online', 'offline'] StatusType = Literal["idle", "dnd", "online", "offline"]
class PartialPresenceUpdate(TypedDict): class PartialPresenceUpdate(TypedDict):

View File

@ -30,6 +30,7 @@ from .user import User
from .team import Team from .team import Team
from .snowflake import Snowflake from .snowflake import Snowflake
class BaseAppInfo(TypedDict): class BaseAppInfo(TypedDict):
id: Snowflake id: Snowflake
name: str name: str
@ -38,6 +39,7 @@ class BaseAppInfo(TypedDict):
summary: str summary: str
description: str description: str
class _AppInfoOptional(TypedDict, total=False): class _AppInfoOptional(TypedDict, total=False):
team: Team team: Team
guild_id: Snowflake guild_id: Snowflake
@ -48,12 +50,14 @@ class _AppInfoOptional(TypedDict, total=False):
hook: bool hook: bool
max_participants: int max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional): class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str] rpc_origins: List[str]
owner: User owner: User
bot_public: bool bot_public: bool
bot_require_code_grant: bool bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False): class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str] rpc_origins: List[str]
cover_image: str cover_image: str
@ -63,5 +67,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
max_participants: int max_participants: int
flags: int flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo): class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass pass

View File

@ -84,31 +84,40 @@ AuditLogEvent = Literal[
class _AuditLogChange_Str(TypedDict): class _AuditLogChange_Str(TypedDict):
key: Literal[ key: Literal[
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags' "name",
"description",
"preferred_locale",
"vanity_url_code",
"topic",
"code",
"allow",
"deny",
"permissions",
"tags",
] ]
new_value: str new_value: str
old_value: str old_value: str
class _AuditLogChange_AssetHash(TypedDict): class _AuditLogChange_AssetHash(TypedDict):
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset'] key: Literal["icon_hash", "splash_hash", "discovery_splash_hash", "banner_hash", "avatar_hash", "asset"]
new_value: str new_value: str
old_value: str old_value: str
class _AuditLogChange_Snowflake(TypedDict): class _AuditLogChange_Snowflake(TypedDict):
key: Literal[ key: Literal[
'id', "id",
'owner_id', "owner_id",
'afk_channel_id', "afk_channel_id",
'rules_channel_id', "rules_channel_id",
'public_updates_channel_id', "public_updates_channel_id",
'widget_channel_id', "widget_channel_id",
'system_channel_id', "system_channel_id",
'application_id', "application_id",
'channel_id', "channel_id",
'inviter_id', "inviter_id",
'guild_id', "guild_id",
] ]
new_value: Snowflake new_value: Snowflake
old_value: Snowflake old_value: Snowflake
@ -116,20 +125,20 @@ class _AuditLogChange_Snowflake(TypedDict):
class _AuditLogChange_Bool(TypedDict): class _AuditLogChange_Bool(TypedDict):
key: Literal[ key: Literal[
'widget_enabled', "widget_enabled",
'nsfw', "nsfw",
'hoist', "hoist",
'mentionable', "mentionable",
'temporary', "temporary",
'deaf', "deaf",
'mute', "mute",
'nick', "nick",
'enabled_emoticons', "enabled_emoticons",
'region', "region",
'rtc_region', "rtc_region",
'available', "available",
'archived', "archived",
'locked', "locked",
] ]
new_value: bool new_value: bool
old_value: bool old_value: bool
@ -137,72 +146,72 @@ class _AuditLogChange_Bool(TypedDict):
class _AuditLogChange_Int(TypedDict): class _AuditLogChange_Int(TypedDict):
key: Literal[ key: Literal[
'afk_timeout', "afk_timeout",
'prune_delete_days', "prune_delete_days",
'position', "position",
'bitrate', "bitrate",
'rate_limit_per_user', "rate_limit_per_user",
'color', "color",
'max_uses', "max_uses",
'max_age', "max_age",
'user_limit', "user_limit",
'auto_archive_duration', "auto_archive_duration",
'default_auto_archive_duration', "default_auto_archive_duration",
] ]
new_value: int new_value: int
old_value: int old_value: int
class _AuditLogChange_ListRole(TypedDict): class _AuditLogChange_ListRole(TypedDict):
key: Literal['$add', '$remove'] key: Literal["$add", "$remove"]
new_value: List[Role] new_value: List[Role]
old_value: List[Role] old_value: List[Role]
class _AuditLogChange_MFALevel(TypedDict): class _AuditLogChange_MFALevel(TypedDict):
key: Literal['mfa_level'] key: Literal["mfa_level"]
new_value: MFALevel new_value: MFALevel
old_value: MFALevel old_value: MFALevel
class _AuditLogChange_VerificationLevel(TypedDict): class _AuditLogChange_VerificationLevel(TypedDict):
key: Literal['verification_level'] key: Literal["verification_level"]
new_value: VerificationLevel new_value: VerificationLevel
old_value: VerificationLevel old_value: VerificationLevel
class _AuditLogChange_ExplicitContentFilter(TypedDict): class _AuditLogChange_ExplicitContentFilter(TypedDict):
key: Literal['explicit_content_filter'] key: Literal["explicit_content_filter"]
new_value: ExplicitContentFilterLevel new_value: ExplicitContentFilterLevel
old_value: ExplicitContentFilterLevel old_value: ExplicitContentFilterLevel
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict): class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
key: Literal['default_message_notifications'] key: Literal["default_message_notifications"]
new_value: DefaultMessageNotificationLevel new_value: DefaultMessageNotificationLevel
old_value: DefaultMessageNotificationLevel old_value: DefaultMessageNotificationLevel
class _AuditLogChange_ChannelType(TypedDict): class _AuditLogChange_ChannelType(TypedDict):
key: Literal['type'] key: Literal["type"]
new_value: ChannelType new_value: ChannelType
old_value: ChannelType old_value: ChannelType
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict): class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
key: Literal['expire_behavior'] key: Literal["expire_behavior"]
new_value: IntegrationExpireBehavior new_value: IntegrationExpireBehavior
old_value: IntegrationExpireBehavior old_value: IntegrationExpireBehavior
class _AuditLogChange_VideoQualityMode(TypedDict): class _AuditLogChange_VideoQualityMode(TypedDict):
key: Literal['video_quality_mode'] key: Literal["video_quality_mode"]
new_value: VideoQualityMode new_value: VideoQualityMode
old_value: VideoQualityMode old_value: VideoQualityMode
class _AuditLogChange_Overwrites(TypedDict): class _AuditLogChange_Overwrites(TypedDict):
key: Literal['permission_overwrites'] key: Literal["permission_overwrites"]
new_value: List[PermissionOverwrite] new_value: List[PermissionOverwrite]
old_value: List[PermissionOverwrite] old_value: List[PermissionOverwrite]
@ -232,7 +241,7 @@ class AuditEntryInfo(TypedDict):
message_id: Snowflake message_id: Snowflake
count: str count: str
id: Snowflake id: Snowflake
type: Literal['0', '1'] type: Literal["0", "1"]
role_name: str role_name: str

View File

@ -24,49 +24,60 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, TypedDict from typing import List, Literal, TypedDict
class _EmbedFooterOptional(TypedDict, total=False): class _EmbedFooterOptional(TypedDict, total=False):
icon_url: str icon_url: str
proxy_icon_url: str proxy_icon_url: str
class EmbedFooter(_EmbedFooterOptional): class EmbedFooter(_EmbedFooterOptional):
text: str text: str
class _EmbedFieldOptional(TypedDict, total=False): class _EmbedFieldOptional(TypedDict, total=False):
inline: bool inline: bool
class EmbedField(_EmbedFieldOptional): class EmbedField(_EmbedFieldOptional):
name: str name: str
value: str value: str
class EmbedThumbnail(TypedDict, total=False): class EmbedThumbnail(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedVideo(TypedDict, total=False): class EmbedVideo(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedImage(TypedDict, total=False): class EmbedImage(TypedDict, total=False):
url: str url: str
proxy_url: str proxy_url: str
height: int height: int
width: int width: int
class EmbedProvider(TypedDict, total=False): class EmbedProvider(TypedDict, total=False):
name: str name: str
url: str url: str
class EmbedAuthor(TypedDict, total=False): class EmbedAuthor(TypedDict, total=False):
name: str name: str
url: str url: str
icon_url: str icon_url: str
proxy_icon_url: str proxy_icon_url: str
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
EmbedType = Literal["rich", "image", "video", "gifv", "article", "link"]
class Embed(TypedDict, total=False): class Embed(TypedDict, total=False):
title: str title: str

View File

@ -75,28 +75,28 @@ VerificationLevel = Literal[0, 1, 2, 3, 4]
NSFWLevel = Literal[0, 1, 2, 3] NSFWLevel = Literal[0, 1, 2, 3]
PremiumTier = Literal[0, 1, 2, 3] PremiumTier = Literal[0, 1, 2, 3]
GuildFeature = Literal[ GuildFeature = Literal[
'ANIMATED_ICON', "ANIMATED_ICON",
'BANNER', "BANNER",
'COMMERCE', "COMMERCE",
'COMMUNITY', "COMMUNITY",
'DISCOVERABLE', "DISCOVERABLE",
'FEATURABLE', "FEATURABLE",
'INVITE_SPLASH', "INVITE_SPLASH",
'MEMBER_VERIFICATION_GATE_ENABLED', "MEMBER_VERIFICATION_GATE_ENABLED",
'MONETIZATION_ENABLED', "MONETIZATION_ENABLED",
'MORE_EMOJI', "MORE_EMOJI",
'MORE_STICKERS', "MORE_STICKERS",
'NEWS', "NEWS",
'PARTNERED', "PARTNERED",
'PREVIEW_ENABLED', "PREVIEW_ENABLED",
'PRIVATE_THREADS', "PRIVATE_THREADS",
'SEVEN_DAY_THREAD_ARCHIVE', "SEVEN_DAY_THREAD_ARCHIVE",
'THREE_DAY_THREAD_ARCHIVE', "THREE_DAY_THREAD_ARCHIVE",
'TICKETED_EVENTS_ENABLED', "TICKETED_EVENTS_ENABLED",
'VANITY_URL', "VANITY_URL",
'VERIFIED', "VERIFIED",
'VIP_REGIONS', "VIP_REGIONS",
'WELCOME_SCREEN_ENABLED', "WELCOME_SCREEN_ENABLED",
] ]

View File

@ -56,7 +56,7 @@ class PartialIntegration(TypedDict):
account: IntegrationAccount account: IntegrationAccount
IntegrationType = Literal['twitch', 'youtube', 'discord'] IntegrationType = Literal["twitch", "youtube", "discord"]
class BaseIntegration(PartialIntegration): class BaseIntegration(PartialIntegration):

View File

@ -39,6 +39,7 @@ if TYPE_CHECKING:
ApplicationCommandType = Literal[1, 2, 3] ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False): class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption] options: List[ApplicationCommandOption]
type: ApplicationCommandType type: ApplicationCommandType
@ -222,15 +223,12 @@ class MessageInteraction(TypedDict):
user: User user: User
class _EditApplicationCommandOptional(TypedDict, total=False): class _EditApplicationCommandOptional(TypedDict, total=False):
description: str description: str
options: Optional[List[ApplicationCommandOption]] options: Optional[List[ApplicationCommandOption]]
type: ApplicationCommandType type: ApplicationCommandType
default_permission: bool
class EditApplicationCommand(_EditApplicationCommandOptional): class EditApplicationCommand(_EditApplicationCommandOptional):
name: str name: str
default_permission: bool

View File

@ -53,6 +53,7 @@ class _AttachmentOptional(TypedDict, total=False):
height: Optional[int] height: Optional[int]
width: Optional[int] width: Optional[int]
content_type: str content_type: str
ephemeral: bool
spoiler: bool spoiler: bool
@ -128,7 +129,7 @@ class Message(_MessageOptional):
type: MessageType type: MessageType
AllowedMentionType = Literal['roles', 'users', 'everyone'] AllowedMentionType = Literal["roles", "users", "everyone"]
class AllowedMentions(TypedDict): class AllowedMentions(TypedDict):

View File

@ -85,3 +85,14 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False):
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake id: Snowflake
guild_id: Snowflake guild_id: Snowflake
class _TypingEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class TypingEvent(_TypingEventOptional):
channel_id: Snowflake
user_id: Snowflake
timestamp: int

View File

@ -29,12 +29,14 @@ from typing import TypedDict, List, Optional
from .user import PartialUser from .user import PartialUser
from .snowflake import Snowflake from .snowflake import Snowflake
class TeamMember(TypedDict): class TeamMember(TypedDict):
user: PartialUser user: PartialUser
membership_state: int membership_state: int
permissions: List[str] permissions: List[str]
team_id: Snowflake team_id: Snowflake
class Team(TypedDict): class Team(TypedDict):
id: Snowflake id: Snowflake
name: str name: str

View File

@ -27,7 +27,7 @@ from .snowflake import Snowflake
from .member import MemberWithUser from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305'] SupportedModes = Literal["xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"]
class _PartialVoiceStateOptional(TypedDict, total=False): class _PartialVoiceStateOptional(TypedDict, total=False):

View File

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
from typing import Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from typing import Any, Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
import inspect import inspect
import os import os
@ -35,16 +35,16 @@ from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent from ..components import Button as ButtonComponent
__all__ = ( __all__ = (
'Button', "Button",
'button', "button",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .view import View from .view import View
from ..emoji import Emoji from ..emoji import Emoji
B = TypeVar('B', bound='Button') B = TypeVar("B", bound="Button")
V = TypeVar('V', bound='View', covariant=True) V = TypeVar("V", bound="View", covariant=True)
class Button(Item[V]): class Button(Item[V]):
@ -60,7 +60,7 @@ class Button(Item[V]):
The ID of the button that gets received during an interaction. The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID. If this button is for a URL, it does not have a custom ID.
url: Optional[:class:`str`] url: Optional[:class:`str`]
The URL this button sends you to. The URL this button sends you to. This param is automatically casted to :class:`str`.
disabled: :class:`bool` disabled: :class:`bool`
Whether the button is disabled or not. Whether the button is disabled or not.
label: Optional[:class:`str`] label: Optional[:class:`str`]
@ -76,12 +76,12 @@ class Button(Item[V]):
""" """
__item_repr_attributes__: Tuple[str, ...] = ( __item_repr_attributes__: Tuple[str, ...] = (
'style', "style",
'url', "url",
'disabled', "disabled",
'label', "label",
'emoji', "emoji",
'row', "row",
) )
def __init__( def __init__(
@ -91,13 +91,13 @@ class Button(Item[V]):
label: Optional[str] = None, label: Optional[str] = None,
disabled: bool = False, disabled: bool = False,
custom_id: Optional[str] = None, custom_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[Any] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None, row: Optional[int] = None,
): ):
super().__init__() super().__init__()
if custom_id is not None and url is not None: if custom_id is not None and url is not None:
raise TypeError('cannot mix both url and custom_id with Button') raise TypeError("cannot mix both url and custom_id with Button")
self._provided_custom_id = custom_id is not None self._provided_custom_id = custom_id is not None
if url is None and custom_id is None: if url is None and custom_id is None:
@ -112,12 +112,12 @@ class Button(Item[V]):
elif isinstance(emoji, _EmojiTag): elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial() emoji = emoji._to_partial()
else: else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
self._underlying = ButtonComponent._raw_construct( self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button, type=ComponentType.button,
custom_id=custom_id, custom_id=custom_id,
url=url, url=str(url) if url else None,
disabled=disabled, disabled=disabled,
label=label, label=label,
style=style, style=style,
@ -145,7 +145,7 @@ class Button(Item[V]):
@custom_id.setter @custom_id.setter
def custom_id(self, value: Optional[str]): def custom_id(self, value: Optional[str]):
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str') raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value self._underlying.custom_id = value
@ -157,7 +157,7 @@ class Button(Item[V]):
@url.setter @url.setter
def url(self, value: Optional[str]): def url(self, value: Optional[str]):
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str') raise TypeError("url must be None or str")
self._underlying.url = value self._underlying.url = value
@property @property
@ -191,7 +191,7 @@ class Button(Item[V]):
elif isinstance(value, _EmojiTag): elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial() self._underlying.emoji = value._to_partial()
else: else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') raise TypeError(f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead")
else: else:
self._underlying.emoji = None self._underlying.emoji = None
@ -273,17 +273,17 @@ def button(
def decorator(func: ItemCallbackType) -> ItemCallbackType: def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function') raise TypeError("button function must be a coroutine function")
func.__discord_ui_model_type__ = Button func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = { func.__discord_ui_model_kwargs__ = {
'style': style, "style": style,
'custom_id': custom_id, "custom_id": custom_id,
'url': None, "url": None,
'disabled': disabled, "disabled": disabled,
'label': label, "label": label,
'emoji': emoji, "emoji": emoji,
'row': row, "row": row,
} }
return func return func

View File

@ -28,17 +28,15 @@ from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECK
from ..interactions import Interaction from ..interactions import Interaction
__all__ = ( __all__ = ("Item",)
'Item',
)
if TYPE_CHECKING: if TYPE_CHECKING:
from ..enums import ComponentType from ..enums import ComponentType
from .view import View from .view import View
from ..components import Component from ..components import Component
I = TypeVar('I', bound='Item') I = TypeVar("I", bound="Item")
V = TypeVar('V', bound='View', covariant=True) V = TypeVar("V", bound="View", covariant=True)
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]] ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
@ -53,7 +51,7 @@ class Item(Generic[V]):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__item_repr_attributes__: Tuple[str, ...] = ('row',) __item_repr_attributes__: Tuple[str, ...] = ("row",)
def __init__(self): def __init__(self):
self._view: Optional[V] = None self._view: Optional[V] = None
@ -91,8 +89,8 @@ class Item(Generic[V]):
return self._provided_custom_id return self._provided_custom_id
def __repr__(self) -> str: def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__) attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__)
return f'<{self.__class__.__name__} {attrs}>' return f"<{self.__class__.__name__} {attrs}>"
@property @property
def row(self) -> Optional[int]: def row(self) -> Optional[int]:
@ -105,7 +103,7 @@ class Item(Generic[V]):
elif 5 > value >= 0: elif 5 > value >= 0:
self._row = value self._row = value
else: else:
raise ValueError('row cannot be negative or greater than or equal to 5') raise ValueError("row cannot be negative or greater than or equal to 5")
@property @property
def width(self) -> int: def width(self) -> int:

View File

@ -39,8 +39,8 @@ from ..components import (
) )
__all__ = ( __all__ = (
'Select', "Select",
'select', "select",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -50,8 +50,8 @@ if TYPE_CHECKING:
ComponentInteractionData, ComponentInteractionData,
) )
S = TypeVar('S', bound='Select') S = TypeVar("S", bound="Select")
V = TypeVar('V', bound='View', covariant=True) V = TypeVar("V", bound="View", covariant=True)
class Select(Item[V]): class Select(Item[V]):
@ -72,7 +72,7 @@ class Select(Item[V]):
The placeholder text that is shown if nothing is selected, if any. The placeholder text that is shown if nothing is selected, if any.
min_values: :class:`int` min_values: :class:`int`
The minimum number of items that must be chosen for this select menu. The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25. Defaults to 1 and must be between 0 and 25.
max_values: :class:`int` max_values: :class:`int`
The maximum number of items that must be chosen for this select menu. The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25. Defaults to 1 and must be between 1 and 25.
@ -89,11 +89,11 @@ class Select(Item[V]):
""" """
__item_repr_attributes__: Tuple[str, ...] = ( __item_repr_attributes__: Tuple[str, ...] = (
'placeholder', "placeholder",
'min_values', "min_values",
'max_values', "max_values",
'options', "options",
'disabled', "disabled",
) )
def __init__( def __init__(
@ -131,7 +131,7 @@ class Select(Item[V]):
@custom_id.setter @custom_id.setter
def custom_id(self, value: str): def custom_id(self, value: str):
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError('custom_id must be None or str') raise TypeError("custom_id must be None or str")
self._underlying.custom_id = value self._underlying.custom_id = value
@ -143,7 +143,7 @@ class Select(Item[V]):
@placeholder.setter @placeholder.setter
def placeholder(self, value: Optional[str]): def placeholder(self, value: Optional[str]):
if value is not None and not isinstance(value, str): if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str') raise TypeError("placeholder must be None or str")
self._underlying.placeholder = value self._underlying.placeholder = value
@ -173,9 +173,9 @@ class Select(Item[V]):
@options.setter @options.setter
def options(self, value: List[SelectOption]): def options(self, value: List[SelectOption]):
if not isinstance(value, list): if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption') raise TypeError("options must be a list of SelectOption")
if not all(isinstance(obj, SelectOption) for obj in value): if not all(isinstance(obj, SelectOption) for obj in value):
raise TypeError('all list items must subclass SelectOption') raise TypeError("all list items must subclass SelectOption")
self._underlying.options = value self._underlying.options = value
@ -224,7 +224,6 @@ class Select(Item[V]):
default=default, default=default,
) )
self.append_option(option) self.append_option(option)
def append_option(self, option: SelectOption): def append_option(self, option: SelectOption):
@ -242,7 +241,7 @@ class Select(Item[V]):
""" """
if len(self._underlying.options) > 25: if len(self._underlying.options) > 25:
raise ValueError('maximum number of options already provided') raise ValueError("maximum number of options already provided")
self._underlying.options.append(option) self._underlying.options.append(option)
@ -272,7 +271,7 @@ class Select(Item[V]):
def refresh_state(self, interaction: Interaction) -> None: def refresh_state(self, interaction: Interaction) -> None:
data: ComponentInteractionData = interaction.data # type: ignore data: ComponentInteractionData = interaction.data # type: ignore
self._selected_values = data.get('values', []) self._selected_values = data.get("values", [])
@classmethod @classmethod
def from_component(cls: Type[S], component: SelectMenu) -> S: def from_component(cls: Type[S], component: SelectMenu) -> S:
@ -328,7 +327,7 @@ def select(
ordering. The row number must be between 0 and 4 (i.e. zero indexed). ordering. The row number must be between 0 and 4 (i.e. zero indexed).
min_values: :class:`int` min_values: :class:`int`
The minimum number of items that must be chosen for this select menu. The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25. Defaults to 1 and must be between 0 and 25.
max_values: :class:`int` max_values: :class:`int`
The maximum number of items that must be chosen for this select menu. The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25. Defaults to 1 and must be between 1 and 25.
@ -340,17 +339,17 @@ def select(
def decorator(func: ItemCallbackType) -> ItemCallbackType: def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func): if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function') raise TypeError("select function must be a coroutine function")
func.__discord_ui_model_type__ = Select func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = { func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder, "placeholder": placeholder,
'custom_id': custom_id, "custom_id": custom_id,
'row': row, "row": row,
'min_values': min_values, "min_values": min_values,
'max_values': max_values, "max_values": max_values,
'options': options, "options": options,
'disabled': disabled, "disabled": disabled,
} }
return func return func

View File

@ -41,9 +41,7 @@ from ..components import (
SelectMenu as SelectComponent, SelectMenu as SelectComponent,
) )
__all__ = ( __all__ = ("View",)
'View',
)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -74,9 +72,7 @@ def _component_to_item(component: Component) -> Item:
class _ViewWeights: class _ViewWeights:
__slots__ = ( __slots__ = ("weights",)
'weights',
)
def __init__(self, children: List[Item]): def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0] self.weights: List[int] = [0, 0, 0, 0, 0]
@ -92,13 +88,13 @@ class _ViewWeights:
if weight + item.width <= 5: if weight + item.width <= 5:
return index return index
raise ValueError('could not find open space for item') raise ValueError("could not find open space for item")
def add_item(self, item: Item) -> None: def add_item(self, item: Item) -> None:
if item.row is not None: if item.row is not None:
total = self.weights[item.row] + item.width total = self.weights[item.row] + item.width
if total > 5: if total > 5:
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)') raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)")
self.weights[item.row] = total self.weights[item.row] = total
item._rendered_row = item.row item._rendered_row = item.row
else: else:
@ -144,11 +140,11 @@ class View:
children: List[ItemCallbackType] = [] children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__): for base in reversed(cls.__mro__):
for member in base.__dict__.values(): for member in base.__dict__.values():
if hasattr(member, '__discord_ui_model_type__'): if hasattr(member, "__discord_ui_model_type__"):
children.append(member) children.append(member)
if len(children) > 25: if len(children) > 25:
raise TypeError('View cannot have more than 25 children') raise TypeError("View cannot have more than 25 children")
cls.__view_children_items__ = children cls.__view_children_items__ = children
@ -171,7 +167,7 @@ class View:
self.__stopped: asyncio.Future[bool] = loop.create_future() self.__stopped: asyncio.Future[bool] = loop.create_future()
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>' return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>"
async def __timeout_task_impl(self) -> None: async def __timeout_task_impl(self) -> None:
while True: while True:
@ -203,8 +199,8 @@ class View:
components.append( components.append(
{ {
'type': 1, "type": 1,
'components': children, "components": children,
} }
) )
@ -261,10 +257,10 @@ class View:
""" """
if len(self.children) > 25: if len(self.children) > 25:
raise ValueError('maximum number of children exceeded') raise ValueError("maximum number of children exceeded")
if not isinstance(item, Item): if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}') raise TypeError(f"expected Item not {item.__class__!r}")
self.__weights.add_item(item) self.__weights.add_item(item)
@ -344,7 +340,7 @@ class View:
interaction: :class:`~discord.Interaction` interaction: :class:`~discord.Interaction`
The interaction that led to the failure. The interaction that led to the failure.
""" """
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr) print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
async def _scheduled_task(self, item: Item, interaction: Interaction): async def _scheduled_task(self, item: Item, interaction: Interaction):
@ -357,7 +353,7 @@ class View:
return return
await item.callback(interaction) await item.callback(interaction)
if not interaction.response._responded: if not interaction.response.is_done():
await interaction.response.defer() await interaction.response.defer()
except Exception as e: except Exception as e:
return await self.on_error(e, item, interaction) return await self.on_error(e, item, interaction)
@ -377,13 +373,13 @@ class View:
return return
self.__stopped.set_result(True) self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}') asyncio.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}")
def _dispatch_item(self, item: Item, interaction: Interaction): def _dispatch_item(self, item: Item, interaction: Interaction):
if self.__stopped.done(): if self.__stopped.done():
return return
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}') asyncio.create_task(self._scheduled_task(item, interaction), name=f"discord-ui-view-dispatch-{self.id}")
def refresh(self, components: List[Component]): def refresh(self, components: List[Component]):
# This is pretty hacky at the moment # This is pretty hacky at the moment

View File

@ -45,11 +45,11 @@ if TYPE_CHECKING:
__all__ = ( __all__ = (
'User', "User",
'ClientUser', "ClientUser",
) )
BU = TypeVar('BU', bound='BaseUser') BU = TypeVar("BU", bound="BaseUser")
class _UserTag: class _UserTag:
@ -59,16 +59,16 @@ class _UserTag:
class BaseUser(_UserTag): class BaseUser(_UserTag):
__slots__ = ( __slots__ = (
'name', "name",
'id', "id",
'discriminator', "discriminator",
'_avatar', "_avatar",
'_banner', "_banner",
'_accent_colour', "_accent_colour",
'bot', "bot",
'system', "system",
'_public_flags', "_public_flags",
'_state', "_state",
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -80,7 +80,7 @@ class BaseUser(_UserTag):
_state: ConnectionState _state: ConnectionState
_avatar: Optional[str] _avatar: Optional[str]
_banner: Optional[str] _banner: Optional[str]
_accent_colour: Optional[str] _accent_colour: Optional[int]
_public_flags: int _public_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
@ -94,7 +94,10 @@ class BaseUser(_UserTag):
) )
def __str__(self) -> str: def __str__(self) -> str:
return f'{self.name}#{self.discriminator}' return f"{self.name}#{self.discriminator}"
def __int__(self) -> int:
return self.id
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _UserTag) and other.id == self.id
@ -106,15 +109,15 @@ class BaseUser(_UserTag):
return self.id >> 22 return self.id >> 22
def _update(self, data: UserPayload) -> None: def _update(self, data: UserPayload) -> None:
self.name = data['username'] self.name = data["username"]
self.id = int(data['id']) self.id = int(data["id"])
self.discriminator = data['discriminator'] self.discriminator = data["discriminator"]
self._avatar = data['avatar'] self._avatar = data["avatar"]
self._banner = data.get('banner', None) self._banner = data.get("banner", None)
self._accent_colour = data.get('accent_color', None) self._accent_colour = data.get("accent_color", None)
self._public_flags = data.get('public_flags', 0) self._public_flags = data.get("public_flags", 0)
self.bot = data.get('bot', False) self.bot = data.get("bot", False)
self.system = data.get('system', False) self.system = data.get("system", False)
@classmethod @classmethod
def _copy(cls: Type[BU], user: BU) -> BU: def _copy(cls: Type[BU], user: BU) -> BU:
@ -134,11 +137,11 @@ class BaseUser(_UserTag):
def _to_minimal_user_json(self) -> Dict[str, Any]: def _to_minimal_user_json(self) -> Dict[str, Any]:
return { return {
'username': self.name, "username": self.name,
'id': self.id, "id": self.id,
'avatar': self._avatar, "avatar": self._avatar,
'discriminator': self.discriminator, "discriminator": self.discriminator,
'bot': self.bot, "bot": self.bot,
} }
@property @property
@ -237,7 +240,7 @@ class BaseUser(_UserTag):
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention the given user.""" """:class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>' return f"<@{self.id}>"
@property @property
def created_at(self) -> datetime: def created_at(self) -> datetime:
@ -321,7 +324,7 @@ class ClientUser(BaseUser):
Specifies if the user has MFA turned on and working. Specifies if the user has MFA turned on and working.
""" """
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__') __slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__")
if TYPE_CHECKING: if TYPE_CHECKING:
verified: bool verified: bool
@ -334,17 +337,17 @@ class ClientUser(BaseUser):
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}' f"<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>' f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>"
) )
def _update(self, data: UserPayload) -> None: def _update(self, data: UserPayload) -> None:
super()._update(data) super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it # There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False) self.verified = data.get("verified", False)
self.locale = data.get('locale') self.locale = data.get("locale")
self._flags = data.get('flags', 0) self._flags = data.get("flags", 0)
self.mfa_enabled = data.get('mfa_enabled', False) self.mfa_enabled = data.get("mfa_enabled", False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser: async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
"""|coro| """|coro|
@ -385,10 +388,10 @@ class ClientUser(BaseUser):
""" """
payload: Dict[str, Any] = {} payload: Dict[str, Any] = {}
if username is not MISSING: if username is not MISSING:
payload['username'] = username payload["username"] = username
if avatar is not MISSING: if avatar is not MISSING:
payload['avatar'] = _bytes_to_base64_data(avatar) payload["avatar"] = _bytes_to_base64_data(avatar)
data: UserPayload = await self._state.http.edit_profile(payload) data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data) return ClientUser(state=self._state, data=data)
@ -415,6 +418,10 @@ class User(BaseUser, discord.abc.Messageable):
Returns the user's name with discriminator. Returns the user's name with discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes Attributes
----------- -----------
name: :class:`str` name: :class:`str`
@ -429,14 +436,14 @@ class User(BaseUser, discord.abc.Messageable):
Specifies if the user is a system user (i.e. represents Discord officially). Specifies if the user is a system user (i.e. represents Discord officially).
""" """
__slots__ = ('_stored',) __slots__ = ("_stored",)
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data) super().__init__(state=state, data=data)
self._stored: bool = False self._stored: bool = False
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>' return f"<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>"
def __del__(self) -> None: def __del__(self) -> None:
try: try:

View File

@ -72,18 +72,18 @@ else:
__all__ = ( __all__ = (
'oauth_url', "oauth_url",
'snowflake_time', "snowflake_time",
'time_snowflake', "time_snowflake",
'find', "find",
'get', "get",
'sleep_until', "sleep_until",
'utcnow', "utcnow",
'remove_markdown', "remove_markdown",
'escape_markdown', "escape_markdown",
'escape_mentions', "escape_mentions",
'as_chunks', "as_chunks",
'format_dt', "format_dt",
) )
DISCORD_EPOCH = 1420070400000 DISCORD_EPOCH = 1420070400000
@ -97,7 +97,7 @@ class _MissingSentinel:
return False return False
def __repr__(self): def __repr__(self):
return '...' return "..."
MISSING: Any = _MissingSentinel() MISSING: Any = _MissingSentinel()
@ -106,7 +106,7 @@ MISSING: Any = _MissingSentinel()
class _cached_property: class _cached_property:
def __init__(self, function): def __init__(self, function):
self.function = function self.function = function
self.__doc__ = getattr(function, '__doc__') self.__doc__ = getattr(function, "__doc__")
def __get__(self, instance, owner): def __get__(self, instance, owner):
if instance is None: if instance is None:
@ -131,15 +131,14 @@ if TYPE_CHECKING:
class _RequestLike(Protocol): class _RequestLike(Protocol):
headers: Mapping[str, Any] headers: Mapping[str, Any]
P = ParamSpec("P")
P = ParamSpec('P')
else: else:
cached_property = _cached_property cached_property = _cached_property
T = TypeVar('T') T = TypeVar("T")
T_co = TypeVar('T_co', covariant=True) T_co = TypeVar("T_co", covariant=True)
_Iter = Union[Iterator[T], AsyncIterator[T]] _Iter = Union[Iterator[T], AsyncIterator[T]]
@ -147,7 +146,7 @@ class CachedSlotProperty(Generic[T, T_co]):
def __init__(self, name: str, function: Callable[[T], T_co]) -> None: def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.name = name self.name = name
self.function = function self.function = function
self.__doc__ = getattr(function, '__doc__') self.__doc__ = getattr(function, "__doc__")
@overload @overload
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]: def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
@ -177,7 +176,7 @@ class classproperty(Generic[T_co]):
return self.fget(owner) return self.fget(owner)
def __set__(self, instance, value) -> None: def __set__(self, instance, value) -> None:
raise AttributeError('cannot set attribute') raise AttributeError("cannot set attribute")
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]: def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
@ -249,14 +248,14 @@ def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Call
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]: def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func) @functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T: def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.simplefilter('always', DeprecationWarning) # turn off filter warnings.simplefilter("always", DeprecationWarning) # turn off filter
if instead: if instead:
fmt = "{0.__name__} is deprecated, use {1} instead." fmt = "{0.__name__} is deprecated, use {1} instead."
else: else:
fmt = '{0.__name__} is deprecated.' fmt = "{0.__name__} is deprecated."
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
warnings.simplefilter('default', DeprecationWarning) # reset filter warnings.simplefilter("default", DeprecationWarning) # reset filter
return func(*args, **kwargs) return func(*args, **kwargs)
return decorated return decorated
@ -301,18 +300,18 @@ def oauth_url(
:class:`str` :class:`str`
The OAuth2 URL for inviting the bot into guilds. The OAuth2 URL for inviting the bot into guilds.
""" """
url = f'https://discord.com/oauth2/authorize?client_id={client_id}' url = f"https://discord.com/oauth2/authorize?client_id={client_id}"
url += '&scope=' + '+'.join(scopes or ('bot',)) url += "&scope=" + "+".join(scopes or ("bot",))
if permissions is not MISSING: if permissions is not MISSING:
url += f'&permissions={permissions.value}' url += f"&permissions={permissions.value}"
if guild is not MISSING: if guild is not MISSING:
url += f'&guild_id={guild.id}' url += f"&guild_id={guild.id}"
if redirect_uri is not MISSING: if redirect_uri is not MISSING:
from urllib.parse import urlencode from urllib.parse import urlencode
url += '&response_type=code&' + urlencode({'redirect_uri': redirect_uri}) url += "&response_type=code&" + urlencode({"redirect_uri": redirect_uri})
if disable_guild_select: if disable_guild_select:
url += '&disable_guild_select=true' url += "&disable_guild_select=true"
return url return url
@ -435,13 +434,13 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
# Special case the single element call # Special case the single element call
if len(attrs) == 1: if len(attrs) == 1:
k, v = attrs.popitem() k, v = attrs.popitem()
pred = attrget(k.replace('__', '.')) pred = attrget(k.replace("__", "."))
for elem in iterable: for elem in iterable:
if pred(elem) == v: if pred(elem) == v:
return elem return elem
return None return None
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()]
for elem in iterable: for elem in iterable:
if _all(pred(elem) == value for pred, value in converted): if _all(pred(elem) == value for pred, value in converted):
@ -463,46 +462,46 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
def _get_mime_type_for_image(data: bytes): def _get_mime_type_for_image(data: bytes):
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'): if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"):
return 'image/png' return "image/png"
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'): elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
return 'image/jpeg' return "image/jpeg"
elif data.startswith((b'\x47\x49\x46\x38\x37\x61', b'\x47\x49\x46\x38\x39\x61')): elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")):
return 'image/gif' return "image/gif"
elif data.startswith(b'RIFF') and data[8:12] == b'WEBP': elif data.startswith(b"RIFF") and data[8:12] == b"WEBP":
return 'image/webp' return "image/webp"
else: else:
raise InvalidArgument('Unsupported image type given') raise InvalidArgument("Unsupported image type given")
def _bytes_to_base64_data(data: bytes) -> str: def _bytes_to_base64_data(data: bytes) -> str:
fmt = 'data:{mime};base64,{data}' fmt = "data:{mime};base64,{data}"
mime = _get_mime_type_for_image(data) mime = _get_mime_type_for_image(data)
b64 = b64encode(data).decode('ascii') b64 = b64encode(data).decode("ascii")
return fmt.format(mime=mime, data=b64) return fmt.format(mime=mime, data=b64)
if HAS_ORJSON: if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode('utf-8') return orjson.dumps(obj).decode("utf-8")
_from_json = orjson.loads # type: ignore _from_json = orjson.loads # type: ignore
else: else:
def _to_json(obj: Any) -> str: def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True) return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
_from_json = json.loads _from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') reset_after: Optional[str] = request.headers.get("X-Ratelimit-Reset-After")
if use_clock or not reset_after: if use_clock or not reset_after:
utc = datetime.timezone.utc utc = datetime.timezone.utc
now = datetime.datetime.now(utc) now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc)
return (reset - now).total_seconds() return (reset - now).total_seconds()
else: else:
return float(reset_after) return float(reset_after)
@ -612,7 +611,7 @@ class SnowflakeList(array.array):
... ...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False): def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None: def add(self, element: int) -> None:
i = bisect_left(self, element) i = bisect_left(self, element)
@ -627,7 +626,7 @@ class SnowflakeList(array.array):
return i != len(self) and self[i] == element return i != len(self) and self[i] == element
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$') _IS_ASCII = re.compile(r"^[\x00-\x7f]+$")
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int: def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
@ -636,7 +635,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
if match: if match:
return match.endpos return match.endpos
UNICODE_WIDE_CHAR_TYPE = 'WFA' UNICODE_WIDE_CHAR_TYPE = "WFA"
func = unicodedata.east_asian_width func = unicodedata.east_asian_width
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
@ -660,7 +659,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite): if isinstance(invite, Invite):
return invite.code return invite.code
else: else:
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)"
m = re.match(rx, invite) m = re.match(rx, invite)
if m: if m:
return m.group(1) return m.group(1)
@ -688,22 +687,24 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template): if isinstance(code, Template):
return code.code return code.code
else: else:
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)"
m = re.match(rx, code) m = re.match(rx, code)
if m: if m:
return m.group(1) return m.group(1)
return code return code
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|')) _MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?<!\{0})\{0})))".format(c) for c in ("*", "`", "_", "~", "|"))
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)' _MARKDOWN_ESCAPE_COMMON = r"^>(?:>>)?\s|\[.+\]\(.+\)"
_MARKDOWN_ESCAPE_REGEX = re.compile(fr'(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})', re.MULTILINE) _MARKDOWN_ESCAPE_REGEX = re.compile(
fr"(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", re.MULTILINE
)
_URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])' _URL_REGEX = r"(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])"
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})' _MARKDOWN_STOCK_REGEX = fr"(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})"
def remove_markdown(text: str, *, ignore_links: bool = True) -> str: def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
@ -732,11 +733,11 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
def replacement(match): def replacement(match):
groupdict = match.groupdict() groupdict = match.groupdict()
return groupdict.get('url', '') return groupdict.get("url", "")
regex = _MARKDOWN_STOCK_REGEX regex = _MARKDOWN_STOCK_REGEX
if ignore_links: if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})' regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE) return re.sub(regex, replacement, text, 0, re.MULTILINE)
@ -769,18 +770,18 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
def replacement(match): def replacement(match):
groupdict = match.groupdict() groupdict = match.groupdict()
is_url = groupdict.get('url') is_url = groupdict.get("url")
if is_url: if is_url:
return is_url return is_url
return '\\' + groupdict['markdown'] return "\\" + groupdict["markdown"]
regex = _MARKDOWN_STOCK_REGEX regex = _MARKDOWN_STOCK_REGEX
if ignore_links: if ignore_links:
regex = f'(?:{_URL_REGEX}|{regex})' regex = f"(?:{_URL_REGEX}|{regex})"
return re.sub(regex, replacement, text, 0, re.MULTILINE) return re.sub(regex, replacement, text, 0, re.MULTILINE)
else: else:
text = re.sub(r'\\', r'\\\\', text) text = re.sub(r"\\", r"\\\\", text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text) return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text)
def escape_mentions(text: str) -> str: def escape_mentions(text: str) -> str:
@ -806,7 +807,7 @@ def escape_mentions(text: str) -> str:
:class:`str` :class:`str`
The text with the mentions removed. The text with the mentions removed.
""" """
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text) return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text)
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]: def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
@ -870,7 +871,7 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
A new iterator which yields chunks of a given size. A new iterator which yields chunks of a given size.
""" """
if max_size <= 0: if max_size <= 0:
raise ValueError('Chunk sizes must be greater than 0.') raise ValueError("Chunk sizes must be greater than 0.")
if isinstance(iterator, AsyncIterator): if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size) return _achunk(iterator, max_size)
@ -916,11 +917,11 @@ def evaluate_annotation(
cache[tp] = evaluated cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache) return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'): if hasattr(tp, "__args__"):
implicit_str = True implicit_str = True
is_literal = False is_literal = False
args = tp.__args__ args = tp.__args__
if not hasattr(tp, '__origin__'): if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache) return evaluate_annotation(converted, globals, locals, cache)
@ -938,10 +939,12 @@ def evaluate_annotation(
implicit_str = False implicit_str = False
is_literal = True is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args) evaluated_args = tuple(
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.') raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")
if evaluated_args == args: if evaluated_args == args:
return tp return tp
@ -971,7 +974,7 @@ def resolve_annotation(
return evaluate_annotation(annotation, globalns, locals, cache) return evaluate_annotation(annotation, globalns, locals, cache)
TimestampStyle = Literal['f', 'F', 'd', 'D', 't', 'T', 'R'] TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str: def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
@ -1015,5 +1018,5 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
The formatted string. The formatted string.
""" """
if style is None: if style is None:
return f'<t:{int(dt.timestamp())}>' return f"<t:{int(dt.timestamp())}>"
return f'<t:{int(dt.timestamp())}:{style}>' return f"<t:{int(dt.timestamp())}:{style}>"

View File

@ -72,20 +72,20 @@ has_nacl: bool
try: try:
import nacl.secret # type: ignore import nacl.secret # type: ignore
has_nacl = True has_nacl = True
except ImportError: except ImportError:
has_nacl = False has_nacl = False
__all__ = ( __all__ = (
'VoiceProtocol', "VoiceProtocol",
'VoiceClient', "VoiceClient",
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class VoiceProtocol: class VoiceProtocol:
"""A class that represents the Discord voice protocol. """A class that represents the Discord voice protocol.
@ -195,6 +195,7 @@ class VoiceProtocol:
key_id, _ = self.channel._get_voice_client_key() key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id) self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol): class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection. """Represents a Discord voice connection.
@ -221,12 +222,12 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop` loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on. The event loop that the voice client is running on.
""" """
endpoint_ip: str endpoint_ip: str
voice_port: int voice_port: int
secret_key: List[int] secret_key: List[int]
ssrc: int ssrc: int
def __init__(self, client: Client, channel: abc.Connectable): def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl: if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice") raise RuntimeError("PyNaCl library needed in order to use voice")
@ -255,18 +256,20 @@ class VoiceClient(VoiceProtocol):
self.encoder: Encoder = MISSING self.encoder: Encoder = MISSING
self._lite_nonce: int = 0 self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
warn_nacl = not has_nacl warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = ( supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite', "xsalsa20_poly1305_lite",
'xsalsa20_poly1305_suffix', "xsalsa20_poly1305_suffix",
'xsalsa20_poly1305', "xsalsa20_poly1305",
) )
@property @property
def guild(self) -> Optional[Guild]: def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable.""" """Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, 'guild', None) return getattr(self.channel, "guild", None)
@property @property
def user(self) -> ClientUser: def user(self) -> ClientUser:
@ -283,8 +286,8 @@ class VoiceClient(VoiceProtocol):
# connection related # connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id'] self.session_id = data["session_id"]
channel_id = data['channel_id'] channel_id = data["channel_id"]
if not self._handshaking or self._potentially_reconnecting: if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves # If we're done handshaking then we just need to update ourselves
@ -301,20 +304,22 @@ class VoiceClient(VoiceProtocol):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set(): if self._voice_server_complete.is_set():
_log.info('Ignoring extraneous voice server update.') _log.info("Ignoring extraneous voice server update.")
return return
self.token = data.get('token') self.token = data.get("token")
self.server_id = int(data['guild_id']) self.server_id = int(data["guild_id"])
endpoint = data.get('endpoint') endpoint = data.get("endpoint")
if endpoint is None or self.token is None: if endpoint is None or self.token is None:
_log.warning('Awaiting endpoint... This requires waiting. ' \ _log.warning(
'If timeout occurred considering raising the timeout and reconnecting.') "Awaiting endpoint... This requires waiting. "
"If timeout occurred considering raising the timeout and reconnecting."
)
return return
self.endpoint, _, _ = endpoint.rpartition(':') self.endpoint, _, _ = endpoint.rpartition(":")
if self.endpoint.startswith('wss://'): if self.endpoint.startswith("wss://"):
# Just in case, strip it off since we're going to add it later # Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:] self.endpoint = self.endpoint[6:]
@ -335,18 +340,20 @@ class VoiceClient(VoiceProtocol):
await self.channel.guild.change_voice_state(channel=self.channel) await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self) -> None: async def voice_disconnect(self) -> None:
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) _log.info(
"The voice handshake is being terminated for Channel ID %s (Guild ID %s)", self.channel.id, self.guild.id
)
await self.channel.guild.change_voice_state(channel=None) await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None: def prepare_handshake(self) -> None:
self._voice_state_complete.clear() self._voice_state_complete.clear()
self._voice_server_complete.clear() self._voice_server_complete.clear()
self._handshaking = True self._handshaking = True
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) _log.info("Starting voice handshake... (connection attempt %d)", self._connections + 1)
self._connections += 1 self._connections += 1
def finish_handshake(self) -> None: def finish_handshake(self) -> None:
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint) _log.info("Voice handshake complete. Endpoint found %s", self.endpoint)
self._handshaking = False self._handshaking = False
self._voice_server_complete.clear() self._voice_server_complete.clear()
self._voice_state_complete.clear() self._voice_state_complete.clear()
@ -359,8 +366,8 @@ class VoiceClient(VoiceProtocol):
self._connected.set() self._connected.set()
return ws return ws
async def connect(self, *, reconnect: bool, timeout: float) ->None: async def connect(self, *, reconnect: bool, timeout: float) -> None:
_log.info('Connecting to voice...') _log.info("Connecting to voice...")
self.timeout = timeout self.timeout = timeout
for i in range(5): for i in range(5):
@ -388,7 +395,7 @@ class VoiceClient(VoiceProtocol):
break break
except (ConnectionClosed, asyncio.TimeoutError): except (ConnectionClosed, asyncio.TimeoutError):
if reconnect: if reconnect:
_log.exception('Failed to connect to voice... Retrying...') _log.exception("Failed to connect to voice... Retrying...")
await asyncio.sleep(1 + i * 2.0) await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect() await self.voice_disconnect()
continue continue
@ -453,14 +460,14 @@ class VoiceClient(VoiceProtocol):
# 4014 - voice channel has been deleted. # 4014 - voice channel has been deleted.
# 4015 - voice server has crashed # 4015 - voice server has crashed
if exc.code in (1000, 4015): if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code) _log.info("Disconnecting from voice normally, close code %d.", exc.code)
await self.disconnect() await self.disconnect()
break break
if exc.code == 4014: if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.') _log.info("Disconnected from voice by force... potentially reconnecting.")
successful = await self.potential_reconnect() successful = await self.potential_reconnect()
if not successful: if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...') _log.info("Reconnect was unsuccessful, disconnecting from voice normally...")
await self.disconnect() await self.disconnect()
break break
else: else:
@ -471,7 +478,7 @@ class VoiceClient(VoiceProtocol):
raise raise
retry = backoff.delay() retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) _log.exception("Disconnected from voice... Reconnecting in %.2fs.", retry)
self._connected.clear() self._connected.clear()
await asyncio.sleep(retry) await asyncio.sleep(retry)
await self.voice_disconnect() await self.voice_disconnect()
@ -479,7 +486,7 @@ class VoiceClient(VoiceProtocol):
await self.connect(reconnect=True, timeout=self.timeout) await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop. # at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...') _log.warning("Could not connect to voice... Retrying...")
continue continue
async def disconnect(self, *, force: bool = False) -> None: async def disconnect(self, *, force: bool = False) -> None:
@ -527,11 +534,11 @@ class VoiceClient(VoiceProtocol):
# Formulate rtp header # Formulate rtp header
header[0] = 0x80 header[0] = 0x80
header[1] = 0x78 header[1] = 0x78
struct.pack_into('>H', header, 2, self.sequence) struct.pack_into(">H", header, 2, self.sequence)
struct.pack_into('>I', header, 4, self.timestamp) struct.pack_into(">I", header, 4, self.timestamp)
struct.pack_into('>I', header, 8, self.ssrc) struct.pack_into(">I", header, 8, self.ssrc)
encrypt_packet = getattr(self, '_encrypt_' + self.mode) encrypt_packet = getattr(self, "_encrypt_" + self.mode)
return encrypt_packet(header, data) return encrypt_packet(header, data)
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes: def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
@ -551,12 +558,12 @@ class VoiceClient(VoiceProtocol):
box = nacl.secret.SecretBox(bytes(self.secret_key)) box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24) nonce = bytearray(24)
nonce[:4] = struct.pack('>I', self._lite_nonce) nonce[:4] = struct.pack(">I", self._lite_nonce)
self.checked_add('_lite_nonce', 1, 4294967295) self.checked_add("_lite_nonce", 1, 4294967295)
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4] return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None: def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None) -> None:
"""Plays an :class:`AudioSource`. """Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted The finalizer, ``after`` is called after the source has been exhausted
@ -586,13 +593,13 @@ class VoiceClient(VoiceProtocol):
""" """
if not self.is_connected(): if not self.is_connected():
raise ClientException('Not connected to voice.') raise ClientException("Not connected to voice.")
if self.is_playing(): if self.is_playing():
raise ClientException('Already playing audio.') raise ClientException("Already playing audio.")
if not isinstance(source, AudioSource): if not isinstance(source, AudioSource):
raise TypeError(f'source must be an AudioSource not {source.__class__.__name__}') raise TypeError(f"source must be an AudioSource not {source.__class__.__name__}")
if not self.encoder and not source.is_opus(): if not self.encoder and not source.is_opus():
self.encoder = opus.Encoder() self.encoder = opus.Encoder()
@ -635,10 +642,10 @@ class VoiceClient(VoiceProtocol):
@source.setter @source.setter
def source(self, value: AudioSource) -> None: def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource): if not isinstance(value, AudioSource):
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.') raise TypeError(f"expected AudioSource not {value.__class__.__name__}.")
if self._player is None: if self._player is None:
raise ValueError('Not playing anything.') raise ValueError("Not playing anything.")
self._player._set_source(value) self._player._set_source(value)
@ -662,7 +669,7 @@ class VoiceClient(VoiceProtocol):
Encoding the data failed. Encoding the data failed.
""" """
self.checked_add('sequence', 1, 65535) self.checked_add("sequence", 1, 65535)
if encode: if encode:
encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
else: else:
@ -671,6 +678,6 @@ class VoiceClient(VoiceProtocol):
try: try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError: except BlockingIOError:
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) _log.warning("A packet has been dropped (seq: %s, timestamp: %s)", self.sequence, self.timestamp)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

View File

@ -46,10 +46,10 @@ from ..mixins import Hashable
from ..channel import PartialMessageable from ..channel import PartialMessageable
__all__ = ( __all__ = (
'Webhook', "Webhook",
'WebhookMessage', "WebhookMessage",
'PartialWebhookChannel', "PartialWebhookChannel",
'PartialWebhookGuild', "PartialWebhookGuild",
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -120,14 +120,14 @@ class AsyncWebhookAdapter:
self._locks[bucket] = lock = asyncio.Lock() self._locks[bucket] = lock = asyncio.Lock()
if payload is not None: if payload is not None:
headers['Content-Type'] = 'application/json' headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload) to_send = utils._to_json(payload)
if auth_token is not None: if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}' headers["Authorization"] = f"Bot {auth_token}"
if reason is not None: if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ') headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[aiohttp.ClientResponse] = None response: Optional[aiohttp.ClientResponse] = None
data: Optional[Union[Dict[str, Any], str]] = None data: Optional[Union[Dict[str, Any], str]] = None
@ -149,21 +149,23 @@ class AsyncWebhookAdapter:
try: try:
async with session.request(method, url, data=to_send, headers=headers, params=params) as response: async with session.request(method, url, data=to_send, headers=headers, params=params) as response:
_log.debug( _log.debug(
'Webhook ID %s with %s %s has returned status code %s', "Webhook ID %s with %s %s has returned status code %s",
webhook_id, webhook_id,
method, method,
url, url,
response.status, response.status,
) )
data = (await response.text(encoding='utf-8')) or None data = (await response.text(encoding="utf-8")) or None
if data and response.headers['Content-Type'] == 'application/json': if data and response.headers["Content-Type"] == "application/json":
data = json.loads(data) data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining') remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == '0' and response.status != 429: if remaining == "0" and response.status != 429:
delta = utils._parse_ratelimit_header(response) delta = utils._parse_ratelimit_header(response)
_log.debug( _log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta "Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
) )
lock.delay_by(delta) lock.delay_by(delta)
@ -171,11 +173,13 @@ class AsyncWebhookAdapter:
return data return data
if response.status == 429: if response.status == 429:
if not response.headers.get('Via'): if not response.headers.get("Via"):
raise HTTPException(response, data) raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore retry_after: float = data["retry_after"] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after) _log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds", webhook_id, retry_after
)
await asyncio.sleep(retry_after) await asyncio.sleep(retry_after)
continue continue
@ -201,7 +205,7 @@ class AsyncWebhookAdapter:
raise DiscordServerError(response, data) raise DiscordServerError(response, data)
raise HTTPException(response, data) raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.') raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook( def delete_webhook(
self, self,
@ -211,7 +215,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[None]: ) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token) return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token( def delete_webhook_with_token(
@ -222,7 +226,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[None]: ) -> Response[None]:
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("DELETE", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason) return self.request(route, session, reason=reason)
def edit_webhook( def edit_webhook(
@ -234,7 +238,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[WebhookPayload]: ) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token) return self.request(route, session, reason=reason, payload=payload, auth_token=token)
def edit_webhook_with_token( def edit_webhook_with_token(
@ -246,7 +250,7 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
reason: Optional[str] = None, reason: Optional[str] = None,
) -> Response[WebhookPayload]: ) -> Response[WebhookPayload]:
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("PATCH", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload) return self.request(route, session, reason=reason, payload=payload)
def execute_webhook( def execute_webhook(
@ -261,10 +265,10 @@ class AsyncWebhookAdapter:
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
wait: bool = False, wait: bool = False,
) -> Response[Optional[MessagePayload]]: ) -> Response[Optional[MessagePayload]]:
params = {'wait': int(wait)} params = {"wait": int(wait)}
if thread_id: if thread_id:
params['thread_id'] = thread_id params["thread_id"] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("POST", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params) return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def get_webhook_message( def get_webhook_message(
@ -276,8 +280,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[MessagePayload]: ) -> Response[MessagePayload]:
route = Route( route = Route(
'GET', "GET",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -296,8 +300,8 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None, files: Optional[List[File]] = None,
) -> Response[Message]: ) -> Response[Message]:
route = Route( route = Route(
'PATCH', "PATCH",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -313,8 +317,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[None]: ) -> Response[None]:
route = Route( route = Route(
'DELETE', "DELETE",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -328,7 +332,7 @@ class AsyncWebhookAdapter:
*, *,
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[WebhookPayload]: ) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token) return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token( def fetch_webhook_with_token(
@ -338,7 +342,7 @@ class AsyncWebhookAdapter:
*, *,
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[WebhookPayload]: ) -> Response[WebhookPayload]:
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("GET", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session) return self.request(route, session=session)
def create_interaction_response( def create_interaction_response(
@ -351,15 +355,15 @@ class AsyncWebhookAdapter:
data: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None,
) -> Response[None]: ) -> Response[None]:
payload: Dict[str, Any] = { payload: Dict[str, Any] = {
'type': type, "type": type,
} }
if data is not None: if data is not None:
payload['data'] = data payload["data"] = data
route = Route( route = Route(
'POST', "POST",
'/interactions/{webhook_id}/{webhook_token}/callback', "/interactions/{webhook_id}/{webhook_token}/callback",
webhook_id=interaction_id, webhook_id=interaction_id,
webhook_token=token, webhook_token=token,
) )
@ -374,8 +378,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[MessagePayload]: ) -> Response[MessagePayload]:
r = Route( r = Route(
'GET', "GET",
'/webhooks/{webhook_id}/{webhook_token}/messages/@original', "/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id, webhook_id=application_id,
webhook_token=token, webhook_token=token,
) )
@ -392,8 +396,8 @@ class AsyncWebhookAdapter:
files: Optional[List[File]] = None, files: Optional[List[File]] = None,
) -> Response[MessagePayload]: ) -> Response[MessagePayload]:
r = Route( r = Route(
'PATCH', "PATCH",
'/webhooks/{webhook_id}/{webhook_token}/messages/@original', "/webhooks/{webhook_id}/{webhook_token}/messages/@original",
webhook_id=application_id, webhook_id=application_id,
webhook_token=token, webhook_token=token,
) )
@ -407,8 +411,8 @@ class AsyncWebhookAdapter:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
) -> Response[None]: ) -> Response[None]:
r = Route( r = Route(
'DELETE', "DELETE",
'/webhooks/{webhook_id}/{wehook_token}/messages/@original', "/webhooks/{webhook_id}/{wehook_token}/messages/@original",
webhook_id=application_id, webhook_id=application_id,
wehook_token=token, wehook_token=token,
) )
@ -437,82 +441,82 @@ def handle_message_parameters(
previous_allowed_mentions: Optional[AllowedMentions] = None, previous_allowed_mentions: Optional[AllowedMentions] = None,
) -> ExecuteWebhookParameters: ) -> ExecuteWebhookParameters:
if files is not MISSING and file is not MISSING: if files is not MISSING and file is not MISSING:
raise TypeError('Cannot mix file and files keyword arguments.') raise TypeError("Cannot mix file and files keyword arguments.")
if embeds is not MISSING and embed is not MISSING: if embeds is not MISSING and embed is not MISSING:
raise TypeError('Cannot mix embed and embeds keyword arguments.') raise TypeError("Cannot mix embed and embeds keyword arguments.")
payload = {} payload = {}
if embeds is not MISSING: if embeds is not MISSING:
if len(embeds) > 10: if len(embeds) > 10:
raise InvalidArgument('embeds has a maximum of 10 elements.') raise InvalidArgument("embeds has a maximum of 10 elements.")
payload['embeds'] = [e.to_dict() for e in embeds] payload["embeds"] = [e.to_dict() for e in embeds]
if embed is not MISSING: if embed is not MISSING:
if embed is None: if embed is None:
payload['embeds'] = [] payload["embeds"] = []
else: else:
payload['embeds'] = [embed.to_dict()] payload["embeds"] = [embed.to_dict()]
if content is not MISSING: if content is not MISSING:
if content is not None: if content is not None:
payload['content'] = str(content) payload["content"] = str(content)
else: else:
payload['content'] = None payload["content"] = None
if view is not MISSING: if view is not MISSING:
if view is not None: if view is not None:
payload['components'] = view.to_components() payload["components"] = view.to_components()
else: else:
payload['components'] = [] payload["components"] = []
payload['tts'] = tts payload["tts"] = tts
if avatar_url: if avatar_url:
payload['avatar_url'] = str(avatar_url) payload["avatar_url"] = str(avatar_url)
if username: if username:
payload['username'] = username payload["username"] = username
if ephemeral: if ephemeral:
payload['flags'] = 64 payload["flags"] = 64
if allowed_mentions: if allowed_mentions:
if previous_allowed_mentions is not None: if previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.merge(allowed_mentions).to_dict() payload["allowed_mentions"] = previous_allowed_mentions.merge(allowed_mentions).to_dict()
else: else:
payload['allowed_mentions'] = allowed_mentions.to_dict() payload["allowed_mentions"] = allowed_mentions.to_dict()
elif previous_allowed_mentions is not None: elif previous_allowed_mentions is not None:
payload['allowed_mentions'] = previous_allowed_mentions.to_dict() payload["allowed_mentions"] = previous_allowed_mentions.to_dict()
multipart = [] multipart = []
if file is not MISSING: if file is not MISSING:
files = [file] files = [file]
if files: if files:
multipart.append({'name': 'payload_json', 'value': utils._to_json(payload)}) multipart.append({"name": "payload_json", "value": utils._to_json(payload)})
payload = None payload = None
if len(files) == 1: if len(files) == 1:
file = files[0] file = files[0]
multipart.append( multipart.append(
{ {
'name': 'file', "name": "file",
'value': file.fp, "value": file.fp,
'filename': file.filename, "filename": file.filename,
'content_type': 'application/octet-stream', "content_type": "application/octet-stream",
} }
) )
else: else:
for index, file in enumerate(files): for index, file in enumerate(files):
multipart.append( multipart.append(
{ {
'name': f'file{index}', "name": f"file{index}",
'value': file.fp, "value": file.fp,
'filename': file.filename, "filename": file.filename,
'content_type': 'application/octet-stream', "content_type": "application/octet-stream",
} }
) )
return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files) return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files)
async_context: ContextVar[AsyncWebhookAdapter] = ContextVar('async_webhook_context', default=AsyncWebhookAdapter()) async_context: ContextVar[AsyncWebhookAdapter] = ContextVar("async_webhook_context", default=AsyncWebhookAdapter())
class PartialWebhookChannel(Hashable): class PartialWebhookChannel(Hashable):
@ -530,14 +534,14 @@ class PartialWebhookChannel(Hashable):
The partial channel's name. The partial channel's name.
""" """
__slots__ = ('id', 'name') __slots__ = ("id", "name")
def __init__(self, *, data): def __init__(self, *, data):
self.id = int(data['id']) self.id = int(data["id"])
self.name = data['name'] self.name = data["name"]
def __repr__(self): def __repr__(self):
return f'<PartialWebhookChannel name={self.name!r} id={self.id}>' return f"<PartialWebhookChannel name={self.name!r} id={self.id}>"
class PartialWebhookGuild(Hashable): class PartialWebhookGuild(Hashable):
@ -555,16 +559,16 @@ class PartialWebhookGuild(Hashable):
The partial guild's name. The partial guild's name.
""" """
__slots__ = ('id', 'name', '_icon', '_state') __slots__ = ("id", "name", "_icon", "_state")
def __init__(self, *, data, state): def __init__(self, *, data, state):
self._state = state self._state = state
self.id = int(data['id']) self.id = int(data["id"])
self.name = data['name'] self.name = data["name"]
self._icon = data['icon'] self._icon = data["icon"]
def __repr__(self): def __repr__(self):
return f'<PartialWebhookGuild name={self.name!r} id={self.id}>' return f"<PartialWebhookGuild name={self.name!r} id={self.id}>"
@property @property
def icon(self) -> Optional[Asset]: def icon(self) -> Optional[Asset]:
@ -578,11 +582,11 @@ class _FriendlyHttpAttributeErrorHelper:
__slots__ = () __slots__ = ()
def __getattr__(self, attr): def __getattr__(self, attr):
raise AttributeError('PartialWebhookState does not support http methods.') raise AttributeError("PartialWebhookState does not support http methods.")
class _WebhookState: class _WebhookState:
__slots__ = ('_parent', '_webhook') __slots__ = ("_parent", "_webhook")
def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]):
self._webhook: Any = webhook self._webhook: Any = webhook
@ -621,7 +625,7 @@ class _WebhookState:
if self._parent is not None: if self._parent is not None:
return getattr(self._parent, attr) return getattr(self._parent, attr)
raise AttributeError(f'PartialWebhookState does not support {attr!r}.') raise AttributeError(f"PartialWebhookState does not support {attr!r}.")
class WebhookMessage(Message): class WebhookMessage(Message):
@ -750,18 +754,18 @@ class WebhookMessage(Message):
class BaseWebhook(Hashable): class BaseWebhook(Hashable):
__slots__: Tuple[str, ...] = ( __slots__: Tuple[str, ...] = (
'id', "id",
'type', "type",
'guild_id', "guild_id",
'channel_id', "channel_id",
'token', "token",
'auth_token', "auth_token",
'user', "user",
'name', "name",
'_avatar', "_avatar",
'source_channel', "source_channel",
'source_guild', "source_guild",
'_state', "_state",
) )
def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None): def __init__(self, data: WebhookPayload, token: Optional[str] = None, state: Optional[ConnectionState] = None):
@ -770,27 +774,27 @@ class BaseWebhook(Hashable):
self._update(data) self._update(data)
def _update(self, data: WebhookPayload): def _update(self, data: WebhookPayload):
self.id = int(data['id']) self.id = int(data["id"])
self.type = try_enum(WebhookType, int(data['type'])) self.type = try_enum(WebhookType, int(data["type"]))
self.channel_id = utils._get_as_snowflake(data, 'channel_id') self.channel_id = utils._get_as_snowflake(data, "channel_id")
self.guild_id = utils._get_as_snowflake(data, 'guild_id') self.guild_id = utils._get_as_snowflake(data, "guild_id")
self.name = data.get('name') self.name = data.get("name")
self._avatar = data.get('avatar') self._avatar = data.get("avatar")
self.token = data.get('token') self.token = data.get("token")
user = data.get('user') user = data.get("user")
self.user: Optional[Union[BaseUser, User]] = None self.user: Optional[Union[BaseUser, User]] = None
if user is not None: if user is not None:
# state parameter may be _WebhookState # state parameter may be _WebhookState
self.user = User(state=self._state, data=user) # type: ignore self.user = User(state=self._state, data=user) # type: ignore
source_channel = data.get('source_channel') source_channel = data.get("source_channel")
if source_channel: if source_channel:
source_channel = PartialWebhookChannel(data=source_channel) source_channel = PartialWebhookChannel(data=source_channel)
self.source_channel: Optional[PartialWebhookChannel] = source_channel self.source_channel: Optional[PartialWebhookChannel] = source_channel
source_guild = data.get('source_guild') source_guild = data.get("source_guild")
if source_guild: if source_guild:
source_guild = PartialWebhookGuild(data=source_guild, state=self._state) source_guild = PartialWebhookGuild(data=source_guild, state=self._state)
@ -886,6 +890,10 @@ class Webhook(BaseWebhook):
Returns the webhooks's hash. Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Webhooks are now comparable and hashable. Webhooks are now comparable and hashable.
@ -923,22 +931,24 @@ class Webhook(BaseWebhook):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None): def __init__(self, data: WebhookPayload, session: aiohttp.ClientSession, token: Optional[str] = None, state=None):
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session = session
def __repr__(self): def __repr__(self):
return f'<Webhook id={self.id!r}>' return f"<Webhook id={self.id!r}>"
@property @property
def url(self) -> str: def url(self) -> str:
""":class:`str` : Returns the webhook's url.""" """:class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}' return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod @classmethod
def partial(cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None) -> Webhook: def partial(
cls, id: int, token: str, *, session: aiohttp.ClientSession, bot_token: Optional[str] = None
) -> Webhook:
"""Creates a partial :class:`Webhook`. """Creates a partial :class:`Webhook`.
Parameters Parameters
@ -966,9 +976,9 @@ class Webhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token. A partial webhook is just a webhook object with an ID and a token.
""" """
data: WebhookPayload = { data: WebhookPayload = {
'id': id, "id": id,
'type': 1, "type": 1,
'token': token, "token": token,
} }
return cls(data, session, token=bot_token) return cls(data, session, token=bot_token)
@ -1004,24 +1014,24 @@ class Webhook(BaseWebhook):
A partial :class:`Webhook`. A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token. A partial webhook is just a webhook object with an ID and a token.
""" """
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url) m = re.search(r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})", url)
if m is None: if m is None:
raise InvalidArgument('Invalid webhook URL given.') raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict() data: Dict[str, Any] = m.groupdict()
data['type'] = 1 data["type"] = 1
return cls(data, session, token=bot_token) # type: ignore return cls(data, session, token=bot_token) # type: ignore
@classmethod @classmethod
def _as_follower(cls, data, *, channel, user) -> Webhook: def _as_follower(cls, data, *, channel, user) -> Webhook:
name = f"{channel.guild} #{channel}" name = f"{channel.guild} #{channel}"
feed: WebhookPayload = { feed: WebhookPayload = {
'id': data['webhook_id'], "id": data["webhook_id"],
'type': 2, "type": 2,
'name': name, "name": name,
'channel_id': channel.id, "channel_id": channel.id,
'guild_id': channel.guild.id, "guild_id": channel.guild.id,
'user': {'username': user.name, 'discriminator': user.discriminator, 'id': user.id, 'avatar': user._avatar}, "user": {"username": user.name, "discriminator": user.discriminator, "id": user.id, "avatar": user._avatar},
} }
state = channel._state state = channel._state
@ -1075,7 +1085,7 @@ class Webhook(BaseWebhook):
elif self.token: elif self.token:
data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session) data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
else: else:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
return Webhook(data, self.session, token=self.auth_token, state=self._state) return Webhook(data, self.session, token=self.auth_token, state=self._state)
@ -1108,7 +1118,7 @@ class Webhook(BaseWebhook):
This webhook does not have a token associated with it. This webhook does not have a token associated with it.
""" """
if self.token is None and self.auth_token is None: if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get() adapter = async_context.get()
@ -1161,14 +1171,14 @@ class Webhook(BaseWebhook):
or it tried editing a channel without authentication. or it tried editing a channel without authentication.
""" """
if self.token is None and self.auth_token is None: if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
payload = {} payload = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = str(name) if name is not None else None payload["name"] = str(name) if name is not None else None
if avatar is not MISSING: if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter = async_context.get() adapter = async_context.get()
@ -1176,27 +1186,31 @@ class Webhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint # If a channel is given, always use the authenticated endpoint
if channel is not None: if channel is not None:
if self.auth_token is None: if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook') raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id payload["channel_id"] = channel.id
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason) data = await adapter.edit_webhook(
self.id, self.auth_token, payload=payload, session=self.session, reason=reason
)
if prefer_auth and self.auth_token: if prefer_auth and self.auth_token:
data = await adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason) data = await adapter.edit_webhook(
self.id, self.auth_token, payload=payload, session=self.session, reason=reason
)
elif self.token: elif self.token:
data = await adapter.edit_webhook_with_token( data = await adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason self.id, self.token, payload=payload, session=self.session, reason=reason
) )
if data is None: if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned') raise RuntimeError("Unreachable code hit: data was not assigned")
return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data): def _create_message(self, data):
state = _WebhookState(self, parent=self._state) state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...) # state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial # state is artificial
return WebhookMessage(data=data, state=state, channel=channel) # type: ignore return WebhookMessage(data=data, state=state, channel=channel) # type: ignore
@ -1346,22 +1360,22 @@ class Webhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None) previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
if content is None: if content is None:
content = MISSING content = MISSING
application_webhook = self.type is WebhookType.application application_webhook = self.type is WebhookType.application
if ephemeral and not application_webhook: if ephemeral and not application_webhook:
raise InvalidArgument('ephemeral messages can only be sent from application webhooks') raise InvalidArgument("ephemeral messages can only be sent from application webhooks")
if application_webhook: if application_webhook:
wait = True wait = True
if view is not MISSING: if view is not MISSING:
if isinstance(self._state, _WebhookState): if isinstance(self._state, _WebhookState):
raise InvalidArgument('Webhook views require an associated state with the webhook') raise InvalidArgument("Webhook views require an associated state with the webhook")
if ephemeral is True and view.timeout is None: if ephemeral is True and view.timeout is None:
view.timeout = 15 * 60.0 view.timeout = 15 * 60.0
@ -1435,7 +1449,7 @@ class Webhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get() adapter = async_context.get()
data = await adapter.get_webhook_message( data = await adapter.get_webhook_message(
@ -1521,15 +1535,15 @@ class Webhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
if view is not MISSING: if view is not MISSING:
if isinstance(self._state, _WebhookState): if isinstance(self._state, _WebhookState):
raise InvalidArgument('This webhook does not have state associated with it') raise InvalidArgument("This webhook does not have state associated with it")
self._state.prevent_view_updates_for(message_id) self._state.prevent_view_updates_for(message_id)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None) previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
params = handle_message_parameters( params = handle_message_parameters(
content=content, content=content,
file=file, file=file,
@ -1579,7 +1593,7 @@ class Webhook(BaseWebhook):
Deleted a message that is not yours. Deleted a message that is not yours.
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter = async_context.get() adapter = async_context.get()
await adapter.delete_webhook_message( await adapter.delete_webhook_message(

View File

@ -48,8 +48,8 @@ from ..channel import PartialMessageable
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
__all__ = ( __all__ = (
'SyncWebhook', "SyncWebhook",
'SyncWebhookMessage', "SyncWebhookMessage",
) )
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -116,14 +116,14 @@ class WebhookAdapter:
self._locks[bucket] = lock = threading.Lock() self._locks[bucket] = lock = threading.Lock()
if payload is not None: if payload is not None:
headers['Content-Type'] = 'application/json' headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload) to_send = utils._to_json(payload)
if auth_token is not None: if auth_token is not None:
headers['Authorization'] = f'Bot {auth_token}' headers["Authorization"] = f"Bot {auth_token}"
if reason is not None: if reason is not None:
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ') headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
response: Optional[Response] = None response: Optional[Response] = None
data: Optional[Union[Dict[str, Any], str]] = None data: Optional[Union[Dict[str, Any], str]] = None
@ -140,36 +140,38 @@ class WebhookAdapter:
if multipart: if multipart:
file_data = {} file_data = {}
for p in multipart: for p in multipart:
name = p['name'] name = p["name"]
if name == 'payload_json': if name == "payload_json":
to_send = {'payload_json': p['value']} to_send = {"payload_json": p["value"]}
else: else:
file_data[name] = (p['filename'], p['value'], p['content_type']) file_data[name] = (p["filename"], p["value"], p["content_type"])
try: try:
with session.request( with session.request(
method, url, data=to_send, files=file_data, headers=headers, params=params method, url, data=to_send, files=file_data, headers=headers, params=params
) as response: ) as response:
_log.debug( _log.debug(
'Webhook ID %s with %s %s has returned status code %s', "Webhook ID %s with %s %s has returned status code %s",
webhook_id, webhook_id,
method, method,
url, url,
response.status_code, response.status_code,
) )
response.encoding = 'utf-8' response.encoding = "utf-8"
# Compatibility with aiohttp # Compatibility with aiohttp
response.status = response.status_code # type: ignore response.status = response.status_code # type: ignore
data = response.text or None data = response.text or None
if data and response.headers['Content-Type'] == 'application/json': if data and response.headers["Content-Type"] == "application/json":
data = json.loads(data) data = json.loads(data)
remaining = response.headers.get('X-Ratelimit-Remaining') remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == '0' and response.status_code != 429: if remaining == "0" and response.status_code != 429:
delta = utils._parse_ratelimit_header(response) delta = utils._parse_ratelimit_header(response)
_log.debug( _log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta "Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
) )
lock.delay_by(delta) lock.delay_by(delta)
@ -177,11 +179,13 @@ class WebhookAdapter:
return data return data
if response.status_code == 429: if response.status_code == 429:
if not response.headers.get('Via'): if not response.headers.get("Via"):
raise HTTPException(response, data) raise HTTPException(response, data)
retry_after: float = data['retry_after'] # type: ignore retry_after: float = data["retry_after"] # type: ignore
_log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after) _log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds", webhook_id, retry_after
)
time.sleep(retry_after) time.sleep(retry_after)
continue continue
@ -207,7 +211,7 @@ class WebhookAdapter:
raise DiscordServerError(response, data) raise DiscordServerError(response, data)
raise HTTPException(response, data) raise HTTPException(response, data)
raise RuntimeError('Unreachable code in HTTP handling.') raise RuntimeError("Unreachable code in HTTP handling.")
def delete_webhook( def delete_webhook(
self, self,
@ -217,7 +221,7 @@ class WebhookAdapter:
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ):
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token) return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token( def delete_webhook_with_token(
@ -228,7 +232,7 @@ class WebhookAdapter:
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ):
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("DELETE", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason) return self.request(route, session, reason=reason)
def edit_webhook( def edit_webhook(
@ -240,7 +244,7 @@ class WebhookAdapter:
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ):
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token) return self.request(route, session, reason=reason, payload=payload, auth_token=token)
def edit_webhook_with_token( def edit_webhook_with_token(
@ -252,7 +256,7 @@ class WebhookAdapter:
session: Session, session: Session,
reason: Optional[str] = None, reason: Optional[str] = None,
): ):
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("PATCH", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload) return self.request(route, session, reason=reason, payload=payload)
def execute_webhook( def execute_webhook(
@ -267,10 +271,10 @@ class WebhookAdapter:
thread_id: Optional[int] = None, thread_id: Optional[int] = None,
wait: bool = False, wait: bool = False,
): ):
params = {'wait': int(wait)} params = {"wait": int(wait)}
if thread_id: if thread_id:
params['thread_id'] = thread_id params["thread_id"] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("POST", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params) return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def get_webhook_message( def get_webhook_message(
@ -282,8 +286,8 @@ class WebhookAdapter:
session: Session, session: Session,
): ):
route = Route( route = Route(
'GET', "GET",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -302,8 +306,8 @@ class WebhookAdapter:
files: Optional[List[File]] = None, files: Optional[List[File]] = None,
): ):
route = Route( route = Route(
'PATCH', "PATCH",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -319,8 +323,8 @@ class WebhookAdapter:
session: Session, session: Session,
): ):
route = Route( route = Route(
'DELETE', "DELETE",
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}', "/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
webhook_id=webhook_id, webhook_id=webhook_id,
webhook_token=token, webhook_token=token,
message_id=message_id, message_id=message_id,
@ -334,7 +338,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
): ):
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id) route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token) return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token( def fetch_webhook_with_token(
@ -344,7 +348,7 @@ class WebhookAdapter:
*, *,
session: Session, session: Session,
): ):
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token) route = Route("GET", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session) return self.request(route, session=session)
@ -475,6 +479,10 @@ class SyncWebhook(BaseWebhook):
Returns the webhooks's hash. Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Webhooks are now comparable and hashable. Webhooks are now comparable and hashable.
@ -512,22 +520,24 @@ class SyncWebhook(BaseWebhook):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
__slots__: Tuple[str, ...] = ('session',) __slots__: Tuple[str, ...] = ("session",)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None): def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
super().__init__(data, token, state) super().__init__(data, token, state)
self.session = session self.session = session
def __repr__(self): def __repr__(self):
return f'<Webhook id={self.id!r}>' return f"<Webhook id={self.id!r}>"
@property @property
def url(self) -> str: def url(self) -> str:
""":class:`str` : Returns the webhook's url.""" """:class:`str` : Returns the webhook's url."""
return f'https://discord.com/api/webhooks/{self.id}/{self.token}' return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
@classmethod @classmethod
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook: def partial(
cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None
) -> SyncWebhook:
"""Creates a partial :class:`Webhook`. """Creates a partial :class:`Webhook`.
Parameters Parameters
@ -552,15 +562,15 @@ class SyncWebhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token. A partial webhook is just a webhook object with an ID and a token.
""" """
data: WebhookPayload = { data: WebhookPayload = {
'id': id, "id": id,
'type': 1, "type": 1,
'token': token, "token": token,
} }
import requests import requests
if session is not MISSING: if session is not MISSING:
if not isinstance(session, requests.Session): if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}') raise TypeError(f"expected requests.Session not {session.__class__!r}")
else: else:
session = requests # type: ignore session = requests # type: ignore
return cls(data, session, token=bot_token) return cls(data, session, token=bot_token)
@ -593,17 +603,17 @@ class SyncWebhook(BaseWebhook):
A partial :class:`Webhook`. A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token. A partial webhook is just a webhook object with an ID and a token.
""" """
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url) m = re.search(r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})", url)
if m is None: if m is None:
raise InvalidArgument('Invalid webhook URL given.') raise InvalidArgument("Invalid webhook URL given.")
data: Dict[str, Any] = m.groupdict() data: Dict[str, Any] = m.groupdict()
data['type'] = 1 data["type"] = 1
import requests import requests
if session is not MISSING: if session is not MISSING:
if not isinstance(session, requests.Session): if not isinstance(session, requests.Session):
raise TypeError(f'expected requests.Session not {session.__class__!r}') raise TypeError(f"expected requests.Session not {session.__class__!r}")
else: else:
session = requests # type: ignore session = requests # type: ignore
return cls(data, session, token=bot_token) # type: ignore return cls(data, session, token=bot_token) # type: ignore
@ -646,7 +656,7 @@ class SyncWebhook(BaseWebhook):
elif self.token: elif self.token:
data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session) data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
else: else:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state) return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
@ -675,7 +685,7 @@ class SyncWebhook(BaseWebhook):
This webhook does not have a token associated with it. This webhook does not have a token associated with it.
""" """
if self.token is None and self.auth_token is None: if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
@ -727,14 +737,14 @@ class SyncWebhook(BaseWebhook):
The newly edited webhook. The newly edited webhook.
""" """
if self.token is None and self.auth_token is None: if self.token is None and self.auth_token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
payload = {} payload = {}
if name is not MISSING: if name is not MISSING:
payload['name'] = str(name) if name is not None else None payload["name"] = str(name) if name is not None else None
if avatar is not MISSING: if avatar is not MISSING:
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
@ -742,25 +752,27 @@ class SyncWebhook(BaseWebhook):
# If a channel is given, always use the authenticated endpoint # If a channel is given, always use the authenticated endpoint
if channel is not None: if channel is not None:
if self.auth_token is None: if self.auth_token is None:
raise InvalidArgument('Editing channel requires authenticated webhook') raise InvalidArgument("Editing channel requires authenticated webhook")
payload['channel_id'] = channel.id payload["channel_id"] = channel.id
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason) data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
if prefer_auth and self.auth_token: if prefer_auth and self.auth_token:
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason) data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
elif self.token: elif self.token:
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason) data = adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
)
if data is None: if data is None:
raise RuntimeError('Unreachable code hit: data was not assigned') raise RuntimeError("Unreachable code hit: data was not assigned")
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state) return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
def _create_message(self, data): def _create_message(self, data):
state = _WebhookState(self, parent=self._state) state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...) # state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data['channel_id'])) # type: ignore channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial # state is artificial
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
@ -883,9 +895,9 @@ class SyncWebhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None) previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
if content is None: if content is None:
content = MISSING content = MISSING
@ -947,7 +959,7 @@ class SyncWebhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.get_webhook_message( data = adapter.get_webhook_message(
@ -1011,9 +1023,9 @@ class SyncWebhook(BaseWebhook):
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None) previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
params = handle_message_parameters( params = handle_message_parameters(
content=content, content=content,
file=file, file=file,
@ -1056,7 +1068,7 @@ class SyncWebhook(BaseWebhook):
Deleted a message that is not yours. Deleted a message that is not yours.
""" """
if self.token is None: if self.token is None:
raise InvalidArgument('This webhook does not have a token associated with it') raise InvalidArgument("This webhook does not have a token associated with it")
adapter: WebhookAdapter = _get_webhook_adapter() adapter: WebhookAdapter = _get_webhook_adapter()
adapter.delete_webhook_message( adapter.delete_webhook_message(

View File

@ -41,11 +41,12 @@ if TYPE_CHECKING:
) )
__all__ = ( __all__ = (
'WidgetChannel', "WidgetChannel",
'WidgetMember', "WidgetMember",
'Widget', "Widget",
) )
class WidgetChannel: class WidgetChannel:
"""Represents a "partial" widget channel. """Represents a "partial" widget channel.
@ -76,7 +77,8 @@ class WidgetChannel:
position: :class:`int` position: :class:`int`
The channel's position The channel's position
""" """
__slots__ = ('id', 'name', 'position')
__slots__ = ("id", "name", "position")
def __init__(self, id: int, name: str, position: int) -> None: def __init__(self, id: int, name: str, position: int) -> None:
self.id: int = id self.id: int = id
@ -87,18 +89,19 @@ class WidgetChannel:
return self.name return self.name
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>' return f"<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>"
@property @property
def mention(self) -> str: def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel.""" """:class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>' return f"<#{self.id}>"
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC.""" """:class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
class WidgetMember(BaseUser): class WidgetMember(BaseUser):
"""Represents a "partial" member of the widget's guild. """Represents a "partial" member of the widget's guild.
@ -147,29 +150,37 @@ class WidgetMember(BaseUser):
connected_channel: Optional[:class:`WidgetChannel`] connected_channel: Optional[:class:`WidgetChannel`]
Which channel the member is connected to. Which channel the member is connected to.
""" """
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted', __slots__ = (
'connected_channel') "name",
"status",
"nick",
"avatar",
"discriminator",
"id",
"bot",
"activity",
"deafened",
"suppress",
"muted",
"connected_channel",
)
if TYPE_CHECKING: if TYPE_CHECKING:
activity: Optional[Union[BaseActivity, Spotify]] activity: Optional[Union[BaseActivity, Spotify]]
def __init__( def __init__(
self, self, *, state: ConnectionState, data: WidgetMemberPayload, connected_channel: Optional[WidgetChannel] = None
*,
state: ConnectionState,
data: WidgetMemberPayload,
connected_channel: Optional[WidgetChannel] = None
) -> None: ) -> None:
super().__init__(state=state, data=data) super().__init__(state=state, data=data)
self.nick: Optional[str] = data.get('nick') self.nick: Optional[str] = data.get("nick")
self.status: Status = try_enum(Status, data.get('status')) self.status: Status = try_enum(Status, data.get("status"))
self.deafened: Optional[bool] = data.get('deaf', False) or data.get('self_deaf', False) self.deafened: Optional[bool] = data.get("deaf", False) or data.get("self_deaf", False)
self.muted: Optional[bool] = data.get('mute', False) or data.get('self_mute', False) self.muted: Optional[bool] = data.get("mute", False) or data.get("self_mute", False)
self.suppress: Optional[bool] = data.get('suppress', False) self.suppress: Optional[bool] = data.get("suppress", False)
try: try:
game = data['game'] game = data["game"]
except KeyError: except KeyError:
activity = None activity = None
else: else:
@ -190,6 +201,7 @@ class WidgetMember(BaseUser):
""":class:`str`: Returns the member's display name.""" """:class:`str`: Returns the member's display name."""
return self.nick or self.name return self.nick or self.name
class Widget: class Widget:
"""Represents a :class:`Guild` widget. """Represents a :class:`Guild` widget.
@ -227,27 +239,28 @@ class Widget:
retrieved is capped. retrieved is capped.
""" """
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
__slots__ = ("_state", "channels", "_invite", "id", "members", "name")
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state self._state = state
self._invite = data['instant_invite'] self._invite = data["instant_invite"]
self.name: str = data['name'] self.name: str = data["name"]
self.id: int = int(data['id']) self.id: int = int(data["id"])
self.channels: List[WidgetChannel] = [] self.channels: List[WidgetChannel] = []
for channel in data.get('channels', []): for channel in data.get("channels", []):
_id = int(channel['id']) _id = int(channel["id"])
self.channels.append(WidgetChannel(id=_id, name=channel['name'], position=channel['position'])) self.channels.append(WidgetChannel(id=_id, name=channel["name"], position=channel["position"]))
self.members: List[WidgetMember] = [] self.members: List[WidgetMember] = []
channels = {channel.id: channel for channel in self.channels} channels = {channel.id: channel for channel in self.channels}
for member in data.get('members', []): for member in data.get("members", []):
connected_channel = _get_as_snowflake(member, 'channel_id') connected_channel = _get_as_snowflake(member, "channel_id")
if connected_channel in channels: if connected_channel in channels:
connected_channel = channels[connected_channel] # type: ignore connected_channel = channels[connected_channel] # type: ignore
elif connected_channel: elif connected_channel:
connected_channel = WidgetChannel(id=connected_channel, name='', position=0) connected_channel = WidgetChannel(id=connected_channel, name="", position=0)
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore
@ -260,7 +273,7 @@ class Widget:
return False return False
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>' return f"<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>"
@property @property
def created_at(self) -> datetime.datetime: def created_at(self) -> datetime.datetime:

File diff suppressed because it is too large Load Diff

View File

@ -18,45 +18,45 @@ import re
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath(".."))
sys.path.append(os.path.abspath('extensions')) sys.path.append(os.path.abspath("extensions"))
# -- General configuration ------------------------------------------------ # -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here. # If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0' # needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones. # ones.
extensions = [ extensions = [
'builder', "builder",
'sphinx.ext.autodoc', "sphinx.ext.autodoc",
'sphinx.ext.extlinks', "sphinx.ext.extlinks",
'sphinx.ext.intersphinx', "sphinx.ext.intersphinx",
'sphinx.ext.napoleon', "sphinx.ext.napoleon",
'sphinxcontrib_trio', "sphinxcontrib_trio",
'details', "details",
'exception_hierarchy', "exception_hierarchy",
'attributetable', "attributetable",
'resourcelinks', "resourcelinks",
'nitpick_file_ignorer', "nitpick_file_ignorer",
] ]
autodoc_member_order = 'bysource' autodoc_member_order = "bysource"
autodoc_typehints = 'none' autodoc_typehints = "none"
# maybe consider this? # maybe consider this?
# napoleon_attr_annotations = False # napoleon_attr_annotations = False
extlinks = { extlinks = {
'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'GH-'), "issue": ("https://github.com/iDevision/enhanced-discord.py/issues/%s", "GH-"),
} }
# Links used for cross-referencing stuff in other documentation # Links used for cross-referencing stuff in other documentation
intersphinx_mapping = { intersphinx_mapping = {
'py': ('https://docs.python.org/3', None), "py": ("https://docs.python.org/3", None),
'aio': ('https://docs.aiohttp.org/en/stable/', None), "aio": ("https://docs.aiohttp.org/en/stable/", None),
'req': ('https://docs.python-requests.org/en/latest/', None) "req": ("https://docs.python-requests.org/en/latest/", None),
} }
rst_prolog = """ rst_prolog = """
@ -67,20 +67,20 @@ rst_prolog = """
""" """
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix of source filenames. # The suffix of source filenames.
source_suffix = '.rst' source_suffix = ".rst"
# The encoding of source files. # The encoding of source files.
#source_encoding = 'utf-8-sig' # source_encoding = 'utf-8-sig'
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = 'discord.py' project = "discord.py"
copyright = '2015-present, Rapptz' copyright = "2015-present, Rapptz"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
@ -88,15 +88,15 @@ copyright = '2015-present, Rapptz'
# #
# The short X.Y version. # The short X.Y version.
version = '' version = ""
with open('../discord/__init__.py') as f: with open("../discord/__init__.py") as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1) version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = version release = version
# This assumes a tag is available for final releases # This assumes a tag is available for final releases
branch = 'master' if version.endswith('a') else 'v' + version branch = "master" if version.endswith("a") else "v" + version
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
@ -105,49 +105,49 @@ branch = 'master' if version.endswith('a') else 'v' + version
# Usually you set "language" from the command line for these cases. # Usually you set "language" from the command line for these cases.
language = None language = None
locale_dirs = ['locale/'] locale_dirs = ["locale/"]
gettext_compact = False gettext_compact = False
# There are two options for replacing |today|: either, you set today to some # There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used: # non-false value, then it is used:
#today = '' # today = ''
# Else, today_fmt is used as the format for a strftime call. # Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y' # today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
exclude_patterns = ['_build'] exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all # The reST default role (used for this markup: `text`) to use for all
# documents. # documents.
#default_role = None # default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text. # If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True # add_function_parentheses = True
# If true, the current module name will be prepended to all description # If true, the current module name will be prepended to all description
# unit titles (such as .. function::). # unit titles (such as .. function::).
#add_module_names = True # add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the # If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default. # output. They are ignored by default.
#show_authors = False # show_authors = False
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'friendly' pygments_style = "friendly"
# A list of ignored prefixes for module index sorting. # A list of ignored prefixes for module index sorting.
#modindex_common_prefix = [] # modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents. # If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False # keep_warnings = False
# Nitpicky mode options # Nitpicky mode options
nitpick_ignore_files = [ nitpick_ignore_files = [
"migrating_to_async", "migrating_to_async",
"migrating", "migrating",
"whats_new", "whats_new",
] ]
# -- Options for HTML output ---------------------------------------------- # -- Options for HTML output ----------------------------------------------
@ -156,21 +156,20 @@ html_experimental_html5_writer = True
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
html_theme = 'basic' html_theme = "basic"
html_context = { html_context = {
'discord_invite': 'https://discord.gg/r3sSKJJ', "discord_invite": "https://discord.gg/TvqYBrGXEm",
'discord_extensions': [ "discord_extensions": [
('discord.ext.commands', 'ext/commands'), ("discord.ext.commands", "ext/commands"),
('discord.ext.tasks', 'ext/tasks'), ("discord.ext.tasks", "ext/tasks"),
], ],
} }
resource_links = { resource_links = {
'discord': 'https://discord.gg/r3sSKJJ', "discord": "https://discord.gg/TvqYBrGXEm",
'issues': 'https://github.com/Rapptz/discord.py/issues', "issues": "https://github.com/iDevision/enhanced-discord.py/issues",
'discussions': 'https://github.com/Rapptz/discord.py/discussions', "examples": f"https://github.com/iDevision/enhanced-discord.py/tree/{branch}/examples",
'examples': f'https://github.com/Rapptz/discord.py/tree/{branch}/examples',
} }
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
@ -180,155 +179,143 @@ resource_links = {
# } # }
# Add any paths that contain custom themes here, relative to this directory. # Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = [] # html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to # The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation". # "<project> v<release> documentation".
#html_title = None # html_title = None
# A shorter title for the navigation bar. Default is the same as html_title. # A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None # html_short_title = None
# The name of an image file (relative to this directory) to place at the top # The name of an image file (relative to this directory) to place at the top
# of the sidebar. # of the sidebar.
#html_logo = None # html_logo = None
# The name of an image file (within the static path) to use as favicon of the # The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large. # pixels large.
html_favicon = './images/discord_py_logo.ico' html_favicon = "./images/discord_py_logo.ico"
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
# Add any extra paths that contain custom files (such as robots.txt or # Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied # .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation. # directly to the root of the documentation.
#html_extra_path = [] # html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format. # using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y' # html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to # If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities. # typographically correct entities.
#html_use_smartypants = True # html_use_smartypants = True
# Custom sidebar templates, maps document names to template names. # Custom sidebar templates, maps document names to template names.
#html_sidebars = {} # html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to # Additional templates that should be rendered to pages, maps page names to
# template names. # template names.
#html_additional_pages = {} # html_additional_pages = {}
# If false, no module index is generated. # If false, no module index is generated.
#html_domain_indices = True # html_domain_indices = True
# If false, no index is generated. # If false, no index is generated.
#html_use_index = True # html_use_index = True
# If true, the index is split into individual pages for each letter. # If true, the index is split into individual pages for each letter.
#html_split_index = False # html_split_index = False
# If true, links to the reST sources are added to the pages. # If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True # html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True # html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True # html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will # If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the # contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served. # base URL from which the finished HTML is served.
#html_use_opensearch = '' # html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml"). # This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None # html_file_suffix = None
# Language to be used for generating the HTML full-text search index. # Language to be used for generating the HTML full-text search index.
# Sphinx supports the following languages: # Sphinx supports the following languages:
# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
#html_search_language = 'en' # html_search_language = 'en'
# A dictionary with options for the search language support, empty by default. # A dictionary with options for the search language support, empty by default.
# Now only 'ja' uses this config value # Now only 'ja' uses this config value
#html_search_options = {'type': 'default'} # html_search_options = {'type': 'default'}
# The name of a javascript file (relative to the configuration directory) that # The name of a javascript file (relative to the configuration directory) that
# implements a search results scorer. If empty, the default will be used. # implements a search results scorer. If empty, the default will be used.
html_search_scorer = '_static/scorer.js' html_search_scorer = "_static/scorer.js"
html_js_files = [ html_js_files = ["custom.js", "settings.js", "copy.js", "sidebar.js"]
'custom.js',
'settings.js',
'copy.js',
'sidebar.js'
]
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'discord.pydoc' htmlhelp_basename = "discord.pydoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
latex_elements = { latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper', #'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
# The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt',
#'pointsize': '10pt', # Additional stuff for the LaTeX preamble.
#'preamble': '',
# Additional stuff for the LaTeX preamble. # Latex figure (float) alignment
#'preamble': '', #'figure_align': 'htbp',
# Latex figure (float) alignment
#'figure_align': 'htbp',
} }
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
('index', 'discord.py.tex', 'discord.py Documentation', ("index", "discord.py.tex", "discord.py Documentation", "Rapptz", "manual"),
'Rapptz', 'manual'),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
# the title page. # the title page.
#latex_logo = None # latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts, # For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters. # not chapters.
#latex_use_parts = False # latex_use_parts = False
# If true, show page references after internal links. # If true, show page references after internal links.
#latex_show_pagerefs = False # latex_show_pagerefs = False
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
#latex_show_urls = False # latex_show_urls = False
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
#latex_appendices = [] # latex_appendices = []
# If false, no module index is generated. # If false, no module index is generated.
#latex_domain_indices = True # latex_domain_indices = True
# -- Options for manual page output --------------------------------------- # -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
('index', 'discord.py', 'discord.py Documentation',
['Rapptz'], 1)
]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
#man_show_urls = False # man_show_urls = False
# -- Options for Texinfo output ------------------------------------------- # -- Options for Texinfo output -------------------------------------------
@ -337,25 +324,32 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
('index', 'discord.py', 'discord.py Documentation', (
'Rapptz', 'discord.py', 'One line description of project.', "index",
'Miscellaneous'), "discord.py",
"discord.py Documentation",
"Rapptz",
"discord.py",
"One line description of project.",
"Miscellaneous",
),
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
#texinfo_appendices = [] # texinfo_appendices = []
# If false, no module index is generated. # If false, no module index is generated.
#texinfo_domain_indices = True # texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'. # How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote' # texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu. # If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False # texinfo_no_detailmenu = False
def setup(app): def setup(app):
if app.config.language == 'ja': if app.config.language == "ja":
app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None) app.config.intersphinx_mapping["py"] = ("https://docs.python.org/ja/3", None)
app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg' app.config.html_context["discord_invite"] = "https://discord.gg/TvqYBrGXEm"
app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg' app.config.resource_links["discord"] = "https://discord.gg/TvqYBrGXEm"

View File

@ -429,6 +429,12 @@ Converters
.. autofunction:: discord.ext.commands.run_converters .. autofunction:: discord.ext.commands.run_converters
Option
~~~~~~
.. autoclass:: discord.ext.commands.Option
:members:
Flag Converter Flag Converter
~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~

View File

@ -61,6 +61,7 @@ the name to something other than the function would be as simple as doing this:
async def _list(ctx, arg): async def _list(ctx, arg):
pass pass
Parameters Parameters
------------ ------------
@ -133,6 +134,11 @@ at all:
Since the ``args`` variable is a :class:`py:tuple`, Since the ``args`` variable is a :class:`py:tuple`,
you can do anything you would usually do with one. you can do anything you would usually do with one.
.. admonition:: Slash Command Only
This functionally is currently not supported by the slash command API, so is turned into
a single ``STRING`` parameter on discord's end which we do our own parsing on.
Keyword-Only Arguments Keyword-Only Arguments
++++++++++++++++++++++++ ++++++++++++++++++++++++
@ -179,6 +185,12 @@ know how the command was executed. It contains a lot of useful information:
The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you
can do on the :class:`~ext.commands.Context`. can do on the :class:`~ext.commands.Context`.
.. admonition:: Slash Command Only
:attr:`.Context.message` will be fake if in a slash command, it is not
recommended to access if :attr:`.Context.interaction` is not None as most
methods will error due to the message not actually existing.
Converters Converters
------------ ------------
@ -400,47 +412,55 @@ specify.
Under the hood, these are implemented by the :ref:`ext_commands_adv_converters` interface. A table of the equivalent Under the hood, these are implemented by the :ref:`ext_commands_adv_converters` interface. A table of the equivalent
converter is given below: converter is given below:
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| Discord Class | Converter | | Discord Class | Converter | Supported By Slash Commands |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Object` | :class:`~ext.commands.ObjectConverter` | | :class:`Object` | :class:`~ext.commands.ObjectConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Member` | :class:`~ext.commands.MemberConverter` | | :class:`Member` | :class:`~ext.commands.MemberConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`User` | :class:`~ext.commands.UserConverter` | | :class:`User` | :class:`~ext.commands.UserConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Message` | :class:`~ext.commands.MessageConverter` | | :class:`Message` | :class:`~ext.commands.MessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` | | :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` | | :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` | | :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` | | :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` | | :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` | | :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` | | :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Invite` | :class:`~ext.commands.InviteConverter` | | :class:`Thread` | :class:`~ext.commands.ThreadConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Guild` | :class:`~ext.commands.GuildConverter` | | :class:`Invite` | :class:`~ext.commands.InviteConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Role` | :class:`~ext.commands.RoleConverter` | | :class:`Guild` | :class:`~ext.commands.GuildConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Game` | :class:`~ext.commands.GameConverter` | | :class:`Role` | :class:`~ext.commands.RoleConverter` | Yes, as type 8 (ROLE) |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Colour` | :class:`~ext.commands.ColourConverter` | | :class:`Game` | :class:`~ext.commands.GameConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Emoji` | :class:`~ext.commands.EmojiConverter` | | :class:`Colour` | :class:`~ext.commands.ColourConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` | | :class:`Emoji` | :class:`~ext.commands.EmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` | | :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+ +--------------------------+-------------------------------------------------+-----------------------------+
.. admonition:: Slash Command Only
If a slash command is not marked on the table above as supported, it will be sent as type 3 (STRING)
and parsed by normal content parsing, see
`the discord documentation <https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-type>`_
for all supported types by the API.
By providing the converter it allows us to use them as building blocks for another converter: By providing the converter it allows us to use them as building blocks for another converter:
@ -487,6 +507,10 @@ then a special error is raised, :exc:`~ext.commands.BadUnionArgument`.
Note that any valid converter discussed above can be passed in to the argument list of a :data:`typing.Union`. Note that any valid converter discussed above can be passed in to the argument list of a :data:`typing.Union`.
.. admonition:: Slash Command Only
These are not currently supported by the Discord API and will be sent as type 3 (STRING)
typing.Optional typing.Optional
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
@ -680,6 +704,11 @@ In order to customise the flag syntax we also have a few options that can be pas
a command line parser. The syntax is mainly inspired by Discord's search bar input and as a result a command line parser. The syntax is mainly inspired by Discord's search bar input and as a result
all flags need a corresponding value. all flags need a corresponding value.
.. admonition:: Slash Command Only
As these are built very similar to slash command options, they are converted into options and parsed
back into flags when the slash command is executed.
The flag converter is similar to regular commands and allows you to use most types of converters The flag converter is similar to regular commands and allows you to use most types of converters
(with the exception of :class:`~ext.commands.Greedy`) as the type annotation. Some extra support is added for specific (with the exception of :class:`~ext.commands.Greedy`) as the type annotation. Some extra support is added for specific
annotations as described below. annotations as described below.

View File

@ -15,4 +15,5 @@ extension library that handles this for you.
commands commands
cogs cogs
extensions extensions
slash-commands
api api

View File

@ -0,0 +1,23 @@
.. currentmodule:: discord
.. _ext_commands_slash_commands:
Slash Commands
==============
Slash Commands are currently supported in enhanced-discord.py using a system on top of ext.commands.
This system is very simple to use, and can be enabled via :attr:`.Bot.slash_commands` globally,
or only for specific commands via :attr:`.Command.slash_command`.
There is also the parameter ``slash_command_guilds`` which can be passed to either :class:`.Bot` or the command
decorator in order to only upload the commands as guild commands to these specific guild IDs, however this
should only be used for testing or small (<10 guilds) bots.
If you want to add option descriptions to your commands, you should use :class:`.Option`
For troubleshooting, see the :ref:`FAQ <ext_commands_slash_command_troubleshooting>`
.. admonition:: Slash Command Only
For parts of the docs specific to slash commands, look for this box!

View File

@ -9,60 +9,78 @@ import inspect
import os import os
import re import re
class attributetable(nodes.General, nodes.Element): class attributetable(nodes.General, nodes.Element):
pass pass
class attributetablecolumn(nodes.General, nodes.Element): class attributetablecolumn(nodes.General, nodes.Element):
pass pass
class attributetabletitle(nodes.TextElement): class attributetabletitle(nodes.TextElement):
pass pass
class attributetableplaceholder(nodes.General, nodes.Element): class attributetableplaceholder(nodes.General, nodes.Element):
pass pass
class attributetablebadge(nodes.TextElement): class attributetablebadge(nodes.TextElement):
pass pass
class attributetable_item(nodes.Part, nodes.Element): class attributetable_item(nodes.Part, nodes.Element):
pass pass
def visit_attributetable_node(self, node): def visit_attributetable_node(self, node):
class_ = node["python-class"] class_ = node["python-class"]
self.body.append(f'<div class="py-attribute-table" data-move-to-id="{class_}">') self.body.append(f'<div class="py-attribute-table" data-move-to-id="{class_}">')
def visit_attributetablecolumn_node(self, node): def visit_attributetablecolumn_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='py-attribute-table-column')) self.body.append(self.starttag(node, "div", CLASS="py-attribute-table-column"))
def visit_attributetabletitle_node(self, node): def visit_attributetabletitle_node(self, node):
self.body.append(self.starttag(node, 'span')) self.body.append(self.starttag(node, "span"))
def visit_attributetablebadge_node(self, node): def visit_attributetablebadge_node(self, node):
attributes = { attributes = {
'class': 'py-attribute-table-badge', "class": "py-attribute-table-badge",
'title': node['badge-type'], "title": node["badge-type"],
} }
self.body.append(self.starttag(node, 'span', **attributes)) self.body.append(self.starttag(node, "span", **attributes))
def visit_attributetable_item_node(self, node): def visit_attributetable_item_node(self, node):
self.body.append(self.starttag(node, 'li', CLASS='py-attribute-table-entry')) self.body.append(self.starttag(node, "li", CLASS="py-attribute-table-entry"))
def depart_attributetable_node(self, node): def depart_attributetable_node(self, node):
self.body.append('</div>') self.body.append("</div>")
def depart_attributetablecolumn_node(self, node): def depart_attributetablecolumn_node(self, node):
self.body.append('</div>') self.body.append("</div>")
def depart_attributetabletitle_node(self, node): def depart_attributetabletitle_node(self, node):
self.body.append('</span>') self.body.append("</span>")
def depart_attributetablebadge_node(self, node): def depart_attributetablebadge_node(self, node):
self.body.append('</span>') self.body.append("</span>")
def depart_attributetable_item_node(self, node): def depart_attributetable_item_node(self, node):
self.body.append('</li>') self.body.append("</li>")
_name_parser_regex = re.compile(r"(?P<module>[\w.]+\.)?(?P<name>\w+)")
_name_parser_regex = re.compile(r'(?P<module>[\w.]+\.)?(?P<name>\w+)')
class PyAttributeTable(SphinxDirective): class PyAttributeTable(SphinxDirective):
has_content = False has_content = False
@ -74,13 +92,13 @@ class PyAttributeTable(SphinxDirective):
def parse_name(self, content): def parse_name(self, content):
path, name = _name_parser_regex.match(content).groups() path, name = _name_parser_regex.match(content).groups()
if path: if path:
modulename = path.rstrip('.') modulename = path.rstrip(".")
else: else:
modulename = self.env.temp_data.get('autodoc:module') modulename = self.env.temp_data.get("autodoc:module")
if not modulename: if not modulename:
modulename = self.env.ref_context.get('py:module') modulename = self.env.ref_context.get("py:module")
if modulename is None: if modulename is None:
raise RuntimeError('modulename somehow None for %s in %s.' % (content, self.env.docname)) raise RuntimeError("modulename somehow None for %s in %s." % (content, self.env.docname))
return modulename, name return modulename, name
@ -112,29 +130,33 @@ class PyAttributeTable(SphinxDirective):
replaced. replaced.
""" """
content = self.arguments[0].strip() content = self.arguments[0].strip()
node = attributetableplaceholder('') node = attributetableplaceholder("")
modulename, name = self.parse_name(content) modulename, name = self.parse_name(content)
node['python-doc'] = self.env.docname node["python-doc"] = self.env.docname
node['python-module'] = modulename node["python-module"] = modulename
node['python-class'] = name node["python-class"] = name
node['python-full-name'] = f'{modulename}.{name}' node["python-full-name"] = f"{modulename}.{name}"
return [node] return [node]
def build_lookup_table(env): def build_lookup_table(env):
# Given an environment, load up a lookup table of # Given an environment, load up a lookup table of
# full-class-name: objects # full-class-name: objects
result = {} result = {}
domain = env.domains['py'] domain = env.domains["py"]
ignored = { ignored = {
'data', 'exception', 'module', 'class', "data",
"exception",
"module",
"class",
} }
for (fullname, _, objtype, docname, _, _) in domain.get_objects(): for (fullname, _, objtype, docname, _, _) in domain.get_objects():
if objtype in ignored: if objtype in ignored:
continue continue
classname, _, child = fullname.rpartition('.') classname, _, child = fullname.rpartition(".")
try: try:
result[classname].append(child) result[classname].append(child)
except KeyError: except KeyError:
@ -143,36 +165,40 @@ def build_lookup_table(env):
return result return result
TableElement = namedtuple('TableElement', 'fullname label badge') TableElement = namedtuple("TableElement", "fullname label badge")
def process_attributetable(app, doctree, fromdocname): def process_attributetable(app, doctree, fromdocname):
env = app.builder.env env = app.builder.env
lookup = build_lookup_table(env) lookup = build_lookup_table(env)
for node in doctree.traverse(attributetableplaceholder): for node in doctree.traverse(attributetableplaceholder):
modulename, classname, fullname = node['python-module'], node['python-class'], node['python-full-name'] modulename, classname, fullname = node["python-module"], node["python-class"], node["python-full-name"]
groups = get_class_results(lookup, modulename, classname, fullname) groups = get_class_results(lookup, modulename, classname, fullname)
table = attributetable('') table = attributetable("")
for label, subitems in groups.items(): for label, subitems in groups.items():
if not subitems: if not subitems:
continue continue
table.append(class_results_to_node(label, sorted(subitems, key=lambda c: c.label))) table.append(class_results_to_node(label, sorted(subitems, key=lambda c: c.label)))
table['python-class'] = fullname table["python-class"] = fullname
if not table: if not table:
node.replace_self([]) node.replace_self([])
else: else:
node.replace_self([table]) node.replace_self([table])
def get_class_results(lookup, modulename, name, fullname): def get_class_results(lookup, modulename, name, fullname):
module = importlib.import_module(modulename) module = importlib.import_module(modulename)
cls = getattr(module, name) cls = getattr(module, name)
groups = OrderedDict([ groups = OrderedDict(
(_('Attributes'), []), [
(_('Methods'), []), (_("Attributes"), []),
]) (_("Methods"), []),
]
)
try: try:
members = lookup[fullname] members = lookup[fullname]
@ -180,8 +206,8 @@ def get_class_results(lookup, modulename, name, fullname):
return groups return groups
for attr in members: for attr in members:
attrlookup = f'{fullname}.{attr}' attrlookup = f"{fullname}.{attr}"
key = _('Attributes') key = _("Attributes")
badge = None badge = None
label = attr label = attr
value = None value = None
@ -192,53 +218,54 @@ def get_class_results(lookup, modulename, name, fullname):
break break
if value is not None: if value is not None:
doc = value.__doc__ or '' doc = value.__doc__ or ""
if inspect.iscoroutinefunction(value) or doc.startswith('|coro|'): if inspect.iscoroutinefunction(value) or doc.startswith("|coro|"):
key = _('Methods') key = _("Methods")
badge = attributetablebadge('async', 'async') badge = attributetablebadge("async", "async")
badge['badge-type'] = _('coroutine') badge["badge-type"] = _("coroutine")
elif isinstance(value, classmethod): elif isinstance(value, classmethod):
key = _('Methods') key = _("Methods")
label = f'{name}.{attr}' label = f"{name}.{attr}"
badge = attributetablebadge('cls', 'cls') badge = attributetablebadge("cls", "cls")
badge['badge-type'] = _('classmethod') badge["badge-type"] = _("classmethod")
elif inspect.isfunction(value): elif inspect.isfunction(value):
if doc.startswith(('A decorator', 'A shortcut decorator')): if doc.startswith(("A decorator", "A shortcut decorator")):
# finicky but surprisingly consistent # finicky but surprisingly consistent
badge = attributetablebadge('@', '@') badge = attributetablebadge("@", "@")
badge['badge-type'] = _('decorator') badge["badge-type"] = _("decorator")
key = _('Methods') key = _("Methods")
else: else:
key = _('Methods') key = _("Methods")
badge = attributetablebadge('def', 'def') badge = attributetablebadge("def", "def")
badge['badge-type'] = _('method') badge["badge-type"] = _("method")
groups[key].append(TableElement(fullname=attrlookup, label=label, badge=badge)) groups[key].append(TableElement(fullname=attrlookup, label=label, badge=badge))
return groups return groups
def class_results_to_node(key, elements): def class_results_to_node(key, elements):
title = attributetabletitle(key, key) title = attributetabletitle(key, key)
ul = nodes.bullet_list('') ul = nodes.bullet_list("")
for element in elements: for element in elements:
ref = nodes.reference('', '', internal=True, ref = nodes.reference(
refuri='#' + element.fullname, "", "", internal=True, refuri="#" + element.fullname, anchorname="", *[nodes.Text(element.label)]
anchorname='', )
*[nodes.Text(element.label)]) para = addnodes.compact_paragraph("", "", ref)
para = addnodes.compact_paragraph('', '', ref)
if element.badge is not None: if element.badge is not None:
ul.append(attributetable_item('', element.badge, para)) ul.append(attributetable_item("", element.badge, para))
else: else:
ul.append(attributetable_item('', para)) ul.append(attributetable_item("", para))
return attributetablecolumn("", title, ul)
return attributetablecolumn('', title, ul)
def setup(app): def setup(app):
app.add_directive('attributetable', PyAttributeTable) app.add_directive("attributetable", PyAttributeTable)
app.add_node(attributetable, html=(visit_attributetable_node, depart_attributetable_node)) app.add_node(attributetable, html=(visit_attributetable_node, depart_attributetable_node))
app.add_node(attributetablecolumn, html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node)) app.add_node(attributetablecolumn, html=(visit_attributetablecolumn_node, depart_attributetablecolumn_node))
app.add_node(attributetabletitle, html=(visit_attributetabletitle_node, depart_attributetabletitle_node)) app.add_node(attributetabletitle, html=(visit_attributetabletitle_node, depart_attributetabletitle_node))
app.add_node(attributetablebadge, html=(visit_attributetablebadge_node, depart_attributetablebadge_node)) app.add_node(attributetablebadge, html=(visit_attributetablebadge_node, depart_attributetablebadge_node))
app.add_node(attributetable_item, html=(visit_attributetable_item_node, depart_attributetable_item_node)) app.add_node(attributetable_item, html=(visit_attributetable_item_node, depart_attributetable_item_node))
app.add_node(attributetableplaceholder) app.add_node(attributetableplaceholder)
app.connect('doctree-resolved', process_attributetable) app.connect("doctree-resolved", process_attributetable)

View File

@ -2,15 +2,15 @@ from sphinx.builders.html import StandaloneHTMLBuilder
from sphinx.environment.adapters.indexentries import IndexEntries from sphinx.environment.adapters.indexentries import IndexEntries
from sphinx.writers.html5 import HTML5Translator from sphinx.writers.html5 import HTML5Translator
class DPYHTML5Translator(HTML5Translator): class DPYHTML5Translator(HTML5Translator):
def visit_section(self, node): def visit_section(self, node):
self.section_level += 1 self.section_level += 1
self.body.append( self.body.append(self.starttag(node, "section"))
self.starttag(node, 'section'))
def depart_section(self, node): def depart_section(self, node):
self.section_level -= 1 self.section_level -= 1
self.body.append('</section>\n') self.body.append("</section>\n")
def visit_table(self, node): def visit_table(self, node):
self.body.append('<div class="table-wrapper">') self.body.append('<div class="table-wrapper">')
@ -18,7 +18,8 @@ class DPYHTML5Translator(HTML5Translator):
def depart_table(self, node): def depart_table(self, node):
super().depart_table(node) super().depart_table(node)
self.body.append('</div>') self.body.append("</div>")
class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder): class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
# This is mostly copy pasted from Sphinx. # This is mostly copy pasted from Sphinx.
@ -28,50 +29,48 @@ class DPYStandaloneHTMLBuilder(StandaloneHTMLBuilder):
genindex = IndexEntries(self.env).create_index(self, group_entries=False) genindex = IndexEntries(self.env).create_index(self, group_entries=False)
indexcounts = [] indexcounts = []
for _k, entries in genindex: for _k, entries in genindex:
indexcounts.append(sum(1 + len(subitems) indexcounts.append(sum(1 + len(subitems) for _, (_, subitems, _) in entries))
for _, (_, subitems, _) in entries))
genindexcontext = { genindexcontext = {
'genindexentries': genindex, "genindexentries": genindex,
'genindexcounts': indexcounts, "genindexcounts": indexcounts,
'split_index': self.config.html_split_index, "split_index": self.config.html_split_index,
} }
if self.config.html_split_index: if self.config.html_split_index:
self.handle_page('genindex', genindexcontext, self.handle_page("genindex", genindexcontext, "genindex-split.html")
'genindex-split.html') self.handle_page("genindex-all", genindexcontext, "genindex.html")
self.handle_page('genindex-all', genindexcontext,
'genindex.html')
for (key, entries), count in zip(genindex, indexcounts): for (key, entries), count in zip(genindex, indexcounts):
ctx = {'key': key, 'entries': entries, 'count': count, ctx = {"key": key, "entries": entries, "count": count, "genindexentries": genindex}
'genindexentries': genindex} self.handle_page("genindex-" + key, ctx, "genindex-single.html")
self.handle_page('genindex-' + key, ctx,
'genindex-single.html')
else: else:
self.handle_page('genindex', genindexcontext, 'genindex.html') self.handle_page("genindex", genindexcontext, "genindex.html")
def add_custom_jinja2(app): def add_custom_jinja2(app):
env = app.builder.templates.environment env = app.builder.templates.environment
env.tests['prefixedwith'] = str.startswith env.tests["prefixedwith"] = str.startswith
env.tests['suffixedwith'] = str.endswith env.tests["suffixedwith"] = str.endswith
def add_builders(app): def add_builders(app):
"""This is necessary because RTD injects their own for some reason.""" """This is necessary because RTD injects their own for some reason."""
app.set_translator('html', DPYHTML5Translator, override=True) app.set_translator("html", DPYHTML5Translator, override=True)
app.add_builder(DPYStandaloneHTMLBuilder, override=True) app.add_builder(DPYStandaloneHTMLBuilder, override=True)
try: try:
original = app.registry.builders['readthedocs'] original = app.registry.builders["readthedocs"]
except KeyError: except KeyError:
pass pass
else: else:
injected_mro = tuple(base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder injected_mro = tuple(
for base in original.mro()[1:]) base if base is not StandaloneHTMLBuilder else DPYStandaloneHTMLBuilder for base in original.mro()[1:]
new_builder = type(original.__name__, injected_mro, {'name': 'readthedocs'}) )
app.set_translator('readthedocs', DPYHTML5Translator, override=True) new_builder = type(original.__name__, injected_mro, {"name": "readthedocs"})
app.set_translator("readthedocs", DPYHTML5Translator, override=True)
app.add_builder(new_builder, override=True) app.add_builder(new_builder, override=True)
def setup(app): def setup(app):
add_builders(app) add_builders(app)
app.connect('builder-inited', add_custom_jinja2) app.connect("builder-inited", add_custom_jinja2)

View File

@ -3,32 +3,39 @@ from docutils.parsers.rst import states, directives
from docutils.parsers.rst.roles import set_classes from docutils.parsers.rst.roles import set_classes
from docutils import nodes from docutils import nodes
class details(nodes.General, nodes.Element): class details(nodes.General, nodes.Element):
pass pass
class summary(nodes.General, nodes.Element): class summary(nodes.General, nodes.Element):
pass pass
def visit_details_node(self, node): def visit_details_node(self, node):
self.body.append(self.starttag(node, 'details', CLASS=node.attributes.get('class', ''))) self.body.append(self.starttag(node, "details", CLASS=node.attributes.get("class", "")))
def visit_summary_node(self, node): def visit_summary_node(self, node):
self.body.append(self.starttag(node, 'summary', CLASS=node.attributes.get('summary-class', ''))) self.body.append(self.starttag(node, "summary", CLASS=node.attributes.get("summary-class", "")))
self.body.append(node.rawsource) self.body.append(node.rawsource)
def depart_details_node(self, node): def depart_details_node(self, node):
self.body.append('</details>\n') self.body.append("</details>\n")
def depart_summary_node(self, node): def depart_summary_node(self, node):
self.body.append('</summary>') self.body.append("</summary>")
class DetailsDirective(Directive): class DetailsDirective(Directive):
final_argument_whitespace = True final_argument_whitespace = True
optional_arguments = 1 optional_arguments = 1
option_spec = { option_spec = {
'class': directives.class_option, "class": directives.class_option,
'summary-class': directives.class_option, "summary-class": directives.class_option,
} }
has_content = True has_content = True
@ -37,7 +44,7 @@ class DetailsDirective(Directive):
set_classes(self.options) set_classes(self.options)
self.assert_has_content() self.assert_has_content()
text = '\n'.join(self.content) text = "\n".join(self.content)
node = details(text, **self.options) node = details(text, **self.options)
if self.arguments: if self.arguments:
@ -48,8 +55,8 @@ class DetailsDirective(Directive):
self.state.nested_parse(self.content, self.content_offset, node) self.state.nested_parse(self.content, self.content_offset, node)
return [node] return [node]
def setup(app): def setup(app):
app.add_node(details, html=(visit_details_node, depart_details_node)) app.add_node(details, html=(visit_details_node, depart_details_node))
app.add_node(summary, html=(visit_summary_node, depart_summary_node)) app.add_node(summary, html=(visit_summary_node, depart_summary_node))
app.add_directive('details', DetailsDirective) app.add_directive("details", DetailsDirective)

View File

@ -4,24 +4,29 @@ from docutils.parsers.rst.roles import set_classes
from docutils import nodes from docutils import nodes
from sphinx.locale import _ from sphinx.locale import _
class exception_hierarchy(nodes.General, nodes.Element): class exception_hierarchy(nodes.General, nodes.Element):
pass pass
def visit_exception_hierarchy_node(self, node): def visit_exception_hierarchy_node(self, node):
self.body.append(self.starttag(node, 'div', CLASS='exception-hierarchy-content')) self.body.append(self.starttag(node, "div", CLASS="exception-hierarchy-content"))
def depart_exception_hierarchy_node(self, node): def depart_exception_hierarchy_node(self, node):
self.body.append('</div>\n') self.body.append("</div>\n")
class ExceptionHierarchyDirective(Directive): class ExceptionHierarchyDirective(Directive):
has_content = True has_content = True
def run(self): def run(self):
self.assert_has_content() self.assert_has_content()
node = exception_hierarchy('\n'.join(self.content)) node = exception_hierarchy("\n".join(self.content))
self.state.nested_parse(self.content, self.content_offset, node) self.state.nested_parse(self.content, self.content_offset, node)
return [node] return [node]
def setup(app): def setup(app):
app.add_node(exception_hierarchy, html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node)) app.add_node(exception_hierarchy, html=(visit_exception_hierarchy_node, depart_exception_hierarchy_node))
app.add_directive('exception_hierarchy', ExceptionHierarchyDirective) app.add_directive("exception_hierarchy", ExceptionHierarchyDirective)

View File

@ -5,18 +5,17 @@ from sphinx.util import logging as sphinx_logging
class NitpickFileIgnorer(logging.Filter): class NitpickFileIgnorer(logging.Filter):
def __init__(self, app: Sphinx) -> None: def __init__(self, app: Sphinx) -> None:
self.app = app self.app = app
super().__init__() super().__init__()
def filter(self, record: sphinx_logging.SphinxLogRecord) -> bool: def filter(self, record: sphinx_logging.SphinxLogRecord) -> bool:
if getattr(record, 'type', None) == 'ref': if getattr(record, "type", None) == "ref":
return record.location.get('refdoc') not in self.app.config.nitpick_ignore_files return record.location.get("refdoc") not in self.app.config.nitpick_ignore_files
return True return True
def setup(app: Sphinx): def setup(app: Sphinx):
app.add_config_value('nitpick_ignore_files', [], '') app.add_config_value("nitpick_ignore_files", [], "")
f = NitpickFileIgnorer(app) f = NitpickFileIgnorer(app)
sphinx_logging.getLogger('sphinx.transforms.post_transforms').logger.addFilter(f) sphinx_logging.getLogger("sphinx.transforms.post_transforms").logger.addFilter(f)

View File

@ -16,13 +16,7 @@ from sphinx.util.typing import RoleFunction
def make_link_role(resource_links: Dict[str, str]) -> RoleFunction: def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
def role( def role(
typ: str, typ: str, rawtext: str, text: str, lineno: int, inliner: Inliner, options: Dict = {}, content: List[str] = []
rawtext: str,
text: str,
lineno: int,
inliner: Inliner,
options: Dict = {},
content: List[str] = []
) -> Tuple[List[Node], List[system_message]]: ) -> Tuple[List[Node], List[system_message]]:
text = utils.unescape(text) text = utils.unescape(text)
@ -32,13 +26,15 @@ def make_link_role(resource_links: Dict[str, str]) -> RoleFunction:
title = full_url title = full_url
pnode = nodes.reference(title, title, internal=False, refuri=full_url) pnode = nodes.reference(title, title, internal=False, refuri=full_url)
return [pnode], [] return [pnode], []
return role return role
def add_link_role(app: Sphinx) -> None: def add_link_role(app: Sphinx) -> None:
app.add_role('resource', make_link_role(app.config.resource_links)) app.add_role("resource", make_link_role(app.config.resource_links))
def setup(app: Sphinx) -> Dict[str, Any]: def setup(app: Sphinx) -> Dict[str, Any]:
app.add_config_value('resource_links', {}, 'env') app.add_config_value("resource_links", {}, "env")
app.connect('builder-inited', add_link_role) app.connect("builder-inited", add_link_role)
return {'version': sphinx.__display_version__, 'parallel_read_safe': True} return {"version": sphinx.__display_version__, "parallel_read_safe": True}

View File

@ -410,3 +410,34 @@ Example: ::
await ctx.send(f'Pushing to {remote} {branch}') await ctx.send(f'Pushing to {remote} {branch}')
This could then be used as ``?git push origin master``. This could then be used as ``?git push origin master``.
How do I make slash commands?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
See :doc:`/ext/commands/slash-commands`
My slash commands aren't showing up!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. _ext_commands_slash_command_troubleshooting:
You need to invite your bot with the ``application.commands`` scope on each guild and
you need the :attr:`Permissions.use_slash_commands` permission in order to see slash commands.
.. image:: /images/discord_oauth2_slash_scope.png
:alt: The scopes checkbox with "bot" and "applications.commands" ticked.
Global slash commands (created by not specifying :attr:`~ext.commands.Bot.slash_command_guilds`) will also take up an
hour to refresh on discord's end, so it is recommended to set :attr:`~ext.commands.Bot.slash_command_guilds` for development.
If none of this works, make sure you are actually running enhanced-discord.py by doing ``print(bot.slash_commands)``
My bot won't start after enabling slash commands!
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This means some of your command metadata is invalid for slash commands.
Make sure your command names and option names are lowercase, and they have to match the regex ``^[\w-]{1,32}$``
If you cannot figure out the problem, you should disable slash commands globally (:attr:`~ext.commands.Bot.slash_commands`\=False)
then go through commands, enabling them specifically with :attr:`~.commands.Command.slash_command`\=True until it
errors, then you can debug the problem with that command specifically.

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Some files were not shown because too many files have changed in this diff Show More