Compare commits

..

3 Commits

Author SHA1 Message Date
Josh
a7c95966ac Fix issue when tasks loop used in classes. 2021-06-28 15:46:36 +10:00
Josh
faa3e84cfb
Remove type-ignore related to Pyright issue. 2021-06-11 14:55:27 +10:00
Josh
3a71d3be5f Use ParamSpec in ext-tasks 2021-06-10 22:54:59 +10:00
154 changed files with 43715 additions and 15301 deletions

View File

@ -1,7 +1,5 @@
## Contributing to discord.py
Credits to the `original lib` by Rapptz <https://github.com/Rapptz/discord.py>
First off, thanks for taking the time to contribute. It makes the library substantially better. :+1:
The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules.
@ -10,9 +8,9 @@ The following is a set of guidelines for contributing to the repository. These a
Generally speaking questions are better suited in our resources below.
- The official support server: https://discord.gg/TvqYBrGXEm
- The official support server: https://discord.gg/r3sSKJJ
- The Discord API server under #python_discord-py: https://discord.gg/discord-api
- [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html)
- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html)
- [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py)
Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience.
@ -34,13 +32,13 @@ If the bug report is missing this information then it'll take us longer to fix t
## Submitting a Pull Request
Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep, and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows the black code format, with a line length limit of `120`
Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows PEP-8 guidelines (mostly) with a column limit of 125.
### Git Commit Guidelines
- Use present tense (e.g. "Add feature" not "Added feature")
- Limit all lines to 120 characters or fewer.
- Reference issues or pull requests outside the first line.
- Limit all lines to 72 characters or less.
- Reference issues or pull requests outside of the first line.
- Please use the shorthand `#123` and not the full URL.
- Commits regarding the commands extension must be prefixed with `[commands]`

View File

@ -6,7 +6,7 @@ body:
attributes:
value: >
Thanks for taking the time to fill out a bug.
If you want real-time support, consider joining our Discord at https://discord.gg/TvqYBrGXEm instead.
If you want real-time support, consider joining our Discord at https://discord.gg/r3sSKJJ instead.
Please note that this form is for bugs only!
- type: input

View File

@ -5,4 +5,4 @@ contact_links:
url: https://github.com/Rapptz/discord.py/discussions
- name: Discord Server
about: Use our official Discord server to ask for help and questions as well.
url: https://discord.gg/TvqYBrGXEm
url: https://discord.gg/r3sSKJJ

View File

@ -1,5 +1,3 @@
<!-- Pull requests that do not fill this information in will likely be closed -->
## Summary
<!-- What is this pull request for? Does it fix any issues? -->

View File

@ -1,38 +0,0 @@
name: CI
on: [push, pull_request]
jobs:
pyright:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Setup node.js (for pyright)
uses: actions/setup-node@v1
with:
node-version: "14"
- name: Run type checking
run: |
npm install -g pyright
pip install .
pyright --lib --verifytypes discord --ignoreexternal
black:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Run linter
uses: psf/black@stable
with:
options: "--line-length 120 --check"
src: "./discord"

View File

@ -1 +0,0 @@

View File

@ -2,4 +2,3 @@ include README.rst
include LICENSE
include requirements.txt
include discord/bin/*.dll
include discord/py.typed

114
README.ja.rst Normal file
View File

@ -0,0 +1,114 @@
discord.py
==========
.. image:: https://discord.com/api/guilds/336642139381301249/embed.png
:target: https://discord.gg/nXzj3dg
:alt: Discordサーバーの招待
.. image:: https://img.shields.io/pypi/v/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPIのバージョン情報
.. image:: https://img.shields.io/pypi/pyversions/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPIのサポートしているPythonのバージョン
discord.py は機能豊富かつモダンで使いやすい、非同期処理にも対応したDiscord用のAPIラッパーです。
主な特徴
-------------
- ``async````await`` を使ったモダンなPythonらしいAPI。
- 適切なレート制限処理
- Discord APIによってサポートされているものを100カバー。
- メモリと速度の両方を最適化。
インストール
-------------
**Python 3.5.3 以降のバージョンが必須です**
完全な音声サポートなしでライブラリをインストールする場合は次のコマンドを実行してください:
.. code:: sh
# Linux/OS X
python3 -m pip install -U discord.py
# Windows
py -3 -m pip install -U discord.py
音声サポートが必要なら、次のコマンドを実行しましょう:
.. code:: sh
# Linux/OS X
python3 -m pip install -U discord.py[voice]
# Windows
py -3 -m pip install -U discord.py[voice]
開発版をインストールしたいのならば、次の手順に従ってください:
.. code:: sh
$ git clone https://github.com/Rapptz/discord.py
$ cd discord.py
$ python3 -m pip install -U .[voice]
オプションパッケージ
~~~~~~~~~~~~~~~~~~~~~~
* PyNaCl (音声サポート用)
Linuxで音声サポートを導入するには、前述のコマンドを実行する前にお気に入りのパッケージマネージャー(例えば ``apt````dnf`` など)を使って以下のパッケージをインストールする必要があります:
* libffi-dev (システムによっては ``libffi-devel``)
* python-dev (例えばPython 3.6用の ``python3.6-dev``)
簡単な例
--------------
.. code:: py
import discord
class MyClient(discord.Client):
async def on_ready(self):
print('Logged on as', self.user)
async def on_message(self, message):
# don't respond to ourselves
if message.author == self.user:
return
if message.content == 'ping':
await message.channel.send('pong')
client = MyClient()
client.run('token')
Botの例
~~~~~~~~~~~~~
.. code:: py
import discord
from discord.ext import commands
bot = commands.Bot(command_prefix='>')
@bot.command()
async def ping(ctx):
await ctx.send('pong')
bot.run('token')
examplesディレクトリに更に多くのサンプルがあります。
リンク
------
- `ドキュメント <https://discordpy.readthedocs.io/ja/latest/index.html>`_
- `公式Discordサーバー <https://discord.gg/nXzj3dg>`_
- `Discord API <https://discord.gg/discord-api>`_

View File

@ -1,53 +1,58 @@
enhanced-discord.py
===================
discord.py
==========
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png
:target: https://discord.gg/TvqYBrGXEm
.. image:: https://discord.com/api/guilds/336642139381301249/embed.png
:target: https://discord.gg/r3sSKJJ
:alt: Discord server invite
.. image:: https://img.shields.io/pypi/v/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy
.. image:: https://img.shields.io/pypi/v/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPI version info
.. image:: https://img.shields.io/pypi/pyversions/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy
.. image:: https://img.shields.io/pypi/pyversions/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPI supported Python versions
A modern, maintained, easy to use, feature-rich, and async ready API wrapper for Discord written in Python.
The Future of enhanced-discord.py
--------------------------
Enhanced discord.py is a fork of Rapptz's discord.py, that went unmaintained (`gist <https://gist.github.com/Rapptz/4a2f62751b9600a31a0d3c78100287f1>`_)
An overview of added features is available on the `custom features page <https://enhanced-dpy.readthedocs.io/en/latest/index.html#custom-features>`_.
A modern, easy to use, feature-rich, and async ready API wrapper for Discord written in Python.
Key Features
-------------
- Modern Pythonic API using ``async`` and ``await``.
- Proper rate limit handling.
- 100% coverage of the supported Discord API.
- Optimised in both speed and memory.
Installing
----------
**Python 3.8 or higher is required**
**Python 3.5.3 or higher is required**
To install the library without full voice support, you can just run the following command:
.. code:: sh
# Linux/macOS
python3 -m pip install -U enhanced-dpy
python3 -m pip install -U discord.py
# Windows
py -3 -m pip install -U enhanced-dpy
py -3 -m pip install -U discord.py
Otherwise to get voice support you should run the following command:
.. code:: sh
# Linux/macOS
python3 -m pip install -U "discord.py[voice]"
# Windows
py -3 -m pip install -U discord.py[voice]
To install the development version, do the following:
.. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py
$ cd enhanced-discord.py
$ git clone https://github.com/Rapptz/discord.py
$ cd discord.py
$ python3 -m pip install -U .[voice]
@ -104,6 +109,6 @@ You can find more examples in the examples directory.
Links
------
- `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_
- `Official Discord Server <https://discord.gg/TvqYBrGXEm>`_
- `Documentation <https://discordpy.readthedocs.io/en/latest/index.html>`_
- `Official Discord Server <https://discord.gg/r3sSKJJ>`_
- `Discord API <https://discord.gg/discord-api>`_

View File

@ -9,16 +9,16 @@ A basic wrapper for the Discord API.
"""
__title__ = "discord"
__author__ = "Rapptz"
__license__ = "MIT"
__copyright__ = "Copyright 2015-present Rapptz"
__version__ = "2.0.0a"
__title__ = 'discord'
__author__ = 'Rapptz'
__license__ = 'MIT'
__copyright__ = 'Copyright 2015-present Rapptz'
__version__ = '2.0.0a'
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
from collections import namedtuple
import logging
from typing import NamedTuple, Literal
from .client import *
from .appinfo import *
@ -60,15 +60,8 @@ from .interactions import *
from .components import *
from .threads import *
VersionInfo = namedtuple('VersionInfo', 'major minor micro releaselevel serial')
class VersionInfo(NamedTuple):
major: int
minor: int
micro: int
releaselevel: Literal["alpha", "beta", "candidate", "final"]
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="alpha", serial=0)
version_info = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@ -31,30 +31,27 @@ import pkg_resources
import aiohttp
import platform
def show_version():
entries = []
entries.append("- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info))
entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info))
version_info = discord.version_info
entries.append("- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info))
if version_info.releaselevel != "final":
pkg = pkg_resources.get_distribution("discord.py")
entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info))
if version_info.releaselevel != 'final':
pkg = pkg_resources.get_distribution('discord.py')
if pkg:
entries.append(f" - discord.py pkg_resources: v{pkg.version}")
entries.append(f' - discord.py pkg_resources: v{pkg.version}')
entries.append(f"- aiohttp v{aiohttp.__version__}")
entries.append(f'- aiohttp v{aiohttp.__version__}')
uname = platform.uname()
entries.append("- system info: {0.system} {0.release} {0.version}".format(uname))
print("\n".join(entries))
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
print('\n'.join(entries))
def core(parser, args):
if args.version:
show_version()
_bot_template = """#!/usr/bin/env python3
bot_template = """#!/usr/bin/env python3
from discord.ext import commands
import discord
@ -80,7 +77,7 @@ bot = Bot()
bot.run(config.token)
"""
_gitignore_template = """# Byte-compiled / optimized / DLL files
gitignore_template = """# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
@ -110,7 +107,7 @@ var/
config.py
"""
_cog_template = '''from discord.ext import commands
cog_template = '''from discord.ext import commands
import discord
class {name}(commands.Cog{attrs}):
@ -123,7 +120,7 @@ def setup(bot):
bot.add_cog({name}(bot))
'''
_cog_extras = """
cog_extras = '''
def cog_unload(self):
# clean up logic goes here
pass
@ -152,68 +149,44 @@ _cog_extras = """
# called after a command is called here
pass
"""
'''
# certain file names and directory names are forbidden
# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx
# although some of this doesn't apply to Linux, we might as well be consistent
_base_table = {
"<": "-",
">": "-",
":": "-",
'"': "-",
'<': '-',
'>': '-',
':': '-',
'"': '-',
# '/': '-', these are fine
# '\\': '-',
"|": "-",
"?": "-",
"*": "-",
'|': '-',
'?': '-',
'*': '-',
}
# NUL (0) and 1-31 are disallowed
_base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table)
translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path):
return name
if sys.platform == "win32":
forbidden = (
"CON",
"PRN",
"AUX",
"NUL",
"COM1",
"COM2",
"COM3",
"COM4",
"COM5",
"COM6",
"COM7",
"COM8",
"COM9",
"LPT1",
"LPT2",
"LPT3",
"LPT4",
"LPT5",
"LPT6",
"LPT7",
"LPT8",
"LPT9",
)
if sys.platform == 'win32':
forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \
'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9')
if len(name) <= 4 and name.upper() in forbidden:
parser.error("invalid directory name given, use a different one")
parser.error('invalid directory name given, use a different one')
name = name.translate(_translation_table)
name = name.translate(translation_table)
if replace_spaces:
name = name.replace(" ", "-")
name = name.replace(' ', '-')
return Path(name)
def newbot(parser, args):
new_directory = to_path(parser, args.directory) / to_path(parser, args.name)
@ -222,114 +195,106 @@ def newbot(parser, args):
try:
new_directory.mkdir(exist_ok=True, parents=True)
except OSError as exc:
parser.error(f"could not create our bot directory ({exc})")
parser.error(f'could not create our bot directory ({exc})')
cogs = new_directory / "cogs"
cogs = new_directory / 'cogs'
try:
cogs.mkdir(exist_ok=True)
init = cogs / "__init__.py"
init = cogs / '__init__.py'
init.touch()
except OSError as exc:
print(f"warning: could not create cogs directory ({exc})")
print(f'warning: could not create cogs directory ({exc})')
try:
with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp:
with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp:
fp.write('token = "place your token here"\ncogs = []\n')
except OSError as exc:
parser.error(f"could not create config file ({exc})")
parser.error(f'could not create config file ({exc})')
try:
with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp:
base = "Bot" if not args.sharded else "AutoShardedBot"
fp.write(_bot_template.format(base=base, prefix=args.prefix))
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot'
fp.write(bot_template.format(base=base, prefix=args.prefix))
except OSError as exc:
parser.error(f"could not create bot file ({exc})")
parser.error(f'could not create bot file ({exc})')
if not args.no_git:
try:
with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp:
fp.write(_gitignore_template)
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
fp.write(gitignore_template)
except OSError as exc:
print(f"warning: could not create .gitignore file ({exc})")
print("successfully made bot at", new_directory)
print(f'warning: could not create .gitignore file ({exc})')
print('successfully made bot at', new_directory)
def newcog(parser, args):
cog_dir = to_path(parser, args.directory)
try:
cog_dir.mkdir(exist_ok=True)
except OSError as exc:
print(f"warning: could not create cogs directory ({exc})")
print(f'warning: could not create cogs directory ({exc})')
directory = cog_dir / to_path(parser, args.name)
directory = directory.with_suffix(".py")
directory = directory.with_suffix('.py')
try:
with open(str(directory), "w", encoding="utf-8") as fp:
attrs = ""
extra = _cog_extras if args.full else ""
with open(str(directory), 'w', encoding='utf-8') as fp:
attrs = ''
extra = cog_extras if args.full else ''
if args.class_name:
name = args.class_name
else:
name = str(directory.stem)
if "-" in name or "_" in name:
translation = str.maketrans("-_", " ")
name = name.translate(translation).title().replace(" ", "")
if '-' in name or '_' in name:
translation = str.maketrans('-_', ' ')
name = name.translate(translation).title().replace(' ', '')
else:
name = name.title()
if args.display_name:
attrs += f', name="{args.display_name}"'
if args.hide_commands:
attrs += ", command_attrs=dict(hidden=True)"
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs))
attrs += ', command_attrs=dict(hidden=True)'
fp.write(cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc:
parser.error(f"could not create cog file ({exc})")
parser.error(f'could not create cog file ({exc})')
else:
print("successfully made cog at", directory)
print('successfully made cog at', directory)
def add_newbot_args(subparser):
parser = subparser.add_parser("newbot", help="creates a command bot project quickly")
parser = subparser.add_parser('newbot', help='creates a command bot project quickly')
parser.set_defaults(func=newbot)
parser.add_argument("name", help="the bot project name")
parser.add_argument("directory", help="the directory to place it in (default: .)", nargs="?", default=Path.cwd())
parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="<prefix>")
parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true")
parser.add_argument("--no-git", help="do not create a .gitignore file", action="store_true", dest="no_git")
parser.add_argument('name', help='the bot project name')
parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd())
parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='<prefix>')
parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true')
parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git')
def add_newcog_args(subparser):
parser = subparser.add_parser("newcog", help="creates a new cog template quickly")
parser = subparser.add_parser('newcog', help='creates a new cog template quickly')
parser.set_defaults(func=newcog)
parser.add_argument("name", help="the cog name")
parser.add_argument(
"directory", help="the directory to place it in (default: cogs)", nargs="?", default=Path("cogs")
)
parser.add_argument("--class-name", help="the class name of the cog (default: <name>)", dest="class_name")
parser.add_argument("--display-name", help="the cog name (default: <name>)")
parser.add_argument("--hide-commands", help="whether to hide all commands in the cog", action="store_true")
parser.add_argument("--full", help="add all special methods as well", action="store_true")
parser.add_argument('name', help='the cog name')
parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs'))
parser.add_argument('--class-name', help='the class name of the cog (default: <name>)', dest='class_name')
parser.add_argument('--display-name', help='the cog name (default: <name>)')
parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true')
parser.add_argument('--full', help='add all special methods as well', action='store_true')
def parse_args():
parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with discord.py")
parser.add_argument("-v", "--version", action="store_true", help="shows the library version")
parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py')
parser.add_argument('-v', '--version', action='store_true', help='shows the library version')
parser.set_defaults(func=core)
subparser = parser.add_subparsers(dest="subcommand", title="subcommands")
subparser = parser.add_subparsers(dest='subcommand', title='subcommands')
add_newbot_args(subparser)
add_newcog_args(subparser)
return parser, parser.parse_args()
def main():
parser, args = parse_args()
args.func(parser, args)
if __name__ == "__main__":
if __name__ == '__main__':
main()

View File

@ -26,21 +26,7 @@ from __future__ import annotations
import copy
import asyncio
from typing import (
Any,
Callable,
Dict,
List,
Optional,
TYPE_CHECKING,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
overload,
runtime_checkable,
)
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable
from .iterators import HistoryIterator
from .context_managers import Typing
@ -52,24 +38,22 @@ from .role import Role
from .invite import Invite
from .file import File
from .voice_client import VoiceClient, VoiceProtocol
from .sticker import GuildSticker, StickerItem
from . import utils
__all__ = (
"Snowflake",
"User",
"PrivateChannel",
"GuildChannel",
"Messageable",
"Connectable",
'Snowflake',
'User',
'PrivateChannel',
'GuildChannel',
'Messageable',
'Connectable',
)
T = TypeVar("T", bound=VoiceProtocol)
T = TypeVar('T', bound=VoiceProtocol)
if TYPE_CHECKING:
from datetime import datetime
from .client import Client
from .user import ClientUser
from .asset import Asset
from .state import ConnectionState
@ -77,28 +61,18 @@ if TYPE_CHECKING:
from .member import Member
from .channel import CategoryChannel
from .embeds import Embed
from .message import Message, MessageReference, PartialMessage
from .channel import TextChannel, DMChannel, GroupChannel, PartialMessageable
from .threads import Thread
from .message import Message, MessageReference
from .enums import InviteTarget
from .ui.view import View
from .types.channel import (
PermissionOverwrite as PermissionOverwritePayload,
Channel as ChannelPayload,
GuildChannel as GuildChannelPayload,
OverwriteType,
)
PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable]
MessageableChannel = Union[PartialMessageableChannel, GroupChannel]
SnowflakeTime = Union["Snowflake", datetime]
MISSING = utils.MISSING
class _Undefined:
def __repr__(self) -> str:
return "see-below"
def __repr__(self):
return 'see-below'
_undefined: Any = _Undefined()
@ -123,6 +97,10 @@ class Snowflake(Protocol):
__slots__ = ()
id: int
@property
def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
raise NotImplementedError
@runtime_checkable
class User(Snowflake, Protocol):
@ -189,23 +167,23 @@ class PrivateChannel(Snowflake, Protocol):
class _Overwrites:
__slots__ = ("id", "allow", "deny", "type")
__slots__ = ('id', 'allow', 'deny', 'type')
ROLE = 0
MEMBER = 1
def __init__(self, data: PermissionOverwritePayload):
self.id: int = int(data["id"])
self.allow: int = int(data.get("allow", 0))
self.deny: int = int(data.get("deny", 0))
self.type: OverwriteType = data["type"]
def __init__(self, **kwargs):
self.id = kwargs.pop('id')
self.allow = int(kwargs.pop('allow', 0))
self.deny = int(kwargs.pop('deny', 0))
self.type = kwargs.pop('type')
def _asdict(self) -> PermissionOverwritePayload:
def _asdict(self):
return {
"id": self.id,
"allow": str(self.allow),
"deny": str(self.deny),
"type": self.type,
'id': self.id,
'allow': str(self.allow),
'deny': str(self.deny),
'type': self.type,
}
def is_role(self) -> bool:
@ -215,7 +193,7 @@ class _Overwrites:
return self.type == 1
GCH = TypeVar("GCH", bound="GuildChannel")
GCH = TypeVar('GCH', bound='GuildChannel')
class GuildChannel:
@ -230,6 +208,11 @@ class GuildChannel:
This ABC must also implement :class:`~discord.abc.Snowflake`.
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
Attributes
-----------
name: :class:`str`
@ -247,10 +230,7 @@ class GuildChannel:
name: str
guild: Guild
type: ChannelType
position: int
category_id: Optional[int]
_state: ConnectionState
_overwrites: List[_Overwrites]
if TYPE_CHECKING:
@ -274,13 +254,13 @@ class GuildChannel:
lock_permissions: bool = False,
*,
reason: Optional[str],
) -> None:
):
if position < 0:
raise InvalidArgument("Channel position cannot be less than 0.")
raise InvalidArgument('Channel position cannot be less than 0.')
http = self._state.http
bucket = self._sorting_bucket
channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket]
channels = [c for c in self.guild.channels if c._sorting_bucket == bucket]
channels.sort(key=lambda c: c.position)
@ -297,107 +277,109 @@ class GuildChannel:
payload = []
for index, c in enumerate(channels):
d: Dict[str, Any] = {"id": c.id, "position": index}
d = {'id': c.id, 'position': index}
if parent_id is not _undefined and c.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
await http.bulk_channel_update(self.guild.id, payload, reason=reason)
self.position = position
if parent_id is not _undefined:
self.category_id = int(parent_id) if parent_id else None
async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]:
async def _edit(self, options, reason):
try:
parent = options.pop("category")
parent = options.pop('category')
except KeyError:
parent_id = _undefined
else:
parent_id = parent and parent.id
try:
options["rate_limit_per_user"] = options.pop("slowmode_delay")
options['rate_limit_per_user'] = options.pop('slowmode_delay')
except KeyError:
pass
try:
rtc_region = options.pop("rtc_region")
rtc_region = options.pop('rtc_region')
except KeyError:
pass
else:
options["rtc_region"] = None if rtc_region is None else str(rtc_region)
options['rtc_region'] = None if rtc_region is None else str(rtc_region)
try:
video_quality_mode = options.pop("video_quality_mode")
video_quality_mode = options.pop('video_quality_mode')
except KeyError:
pass
else:
options["video_quality_mode"] = int(video_quality_mode)
options['video_quality_mode'] = int(video_quality_mode)
lock_permissions = options.pop("sync_permissions", False)
lock_permissions = options.pop('sync_permissions', False)
try:
position = options.pop("position")
position = options.pop('position')
except KeyError:
if parent_id is not _undefined:
if lock_permissions:
category = self.guild.get_channel(parent_id)
if category:
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
options["parent_id"] = parent_id
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
options['parent_id'] = parent_id
elif lock_permissions and self.category_id is not None:
# if we're syncing permissions on a pre-existing channel category without changing it
# we need to update the permissions to point to the pre-existing category
category = self.guild.get_channel(self.category_id)
if category:
options["permission_overwrites"] = [c._asdict() for c in category._overwrites]
options['permission_overwrites'] = [c._asdict() for c in category._overwrites]
else:
await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason)
overwrites = options.get("overwrites", None)
overwrites = options.get('overwrites', None)
if overwrites is not None:
perms = []
for target, perm in overwrites.items():
if not isinstance(perm, PermissionOverwrite):
raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}")
raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}')
allow, deny = perm.pair()
payload = {
"allow": allow.value,
"deny": deny.value,
"id": target.id,
'allow': allow.value,
'deny': deny.value,
'id': target.id,
}
if isinstance(target, Role):
payload["type"] = _Overwrites.ROLE
payload['type'] = _Overwrites.ROLE
else:
payload["type"] = _Overwrites.MEMBER
payload['type'] = _Overwrites.MEMBER
perms.append(payload)
options["permission_overwrites"] = perms
options['permission_overwrites'] = perms
try:
ch_type = options["type"]
ch_type = options['type']
except KeyError:
pass
else:
if not isinstance(ch_type, ChannelType):
raise InvalidArgument("type field must be of type ChannelType")
options["type"] = ch_type.value
raise InvalidArgument('type field must be of type ChannelType')
options['type'] = ch_type.value
if options:
return await self._state.http.edit_channel(self.id, reason=reason, **options)
data = await self._state.http.edit_channel(self.id, reason=reason, **options)
self._update(self.guild, data)
def _fill_overwrites(self, data: GuildChannelPayload) -> None:
def _fill_overwrites(self, data):
self._overwrites = []
everyone_index = 0
everyone_id = self.guild.id
for index, overridden in enumerate(data.get("permission_overwrites", [])):
overwrite = _Overwrites(overridden)
self._overwrites.append(overwrite)
for index, overridden in enumerate(data.get('permission_overwrites', [])):
overridden_id = int(overridden.pop('id'))
self._overwrites.append(_Overwrites(id=overridden_id, **overridden))
if overwrite.type == _Overwrites.MEMBER:
if overridden['type'] == _Overwrites.MEMBER:
continue
if overwrite.id == everyone_id:
if overridden_id == everyone_id:
# the @everyone role is not guaranteed to be the first one
# in the list of permission overwrites, however the permission
# resolution code kind of requires that it is the first one in
@ -429,7 +411,7 @@ class GuildChannel:
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f"<#{self.id}>"
return f'<#{self.id}>'
@property
def created_at(self) -> datetime:
@ -467,7 +449,7 @@ class GuildChannel:
return PermissionOverwrite()
@property
def overwrites(self) -> Dict[Union[Role, Member], PermissionOverwrite]:
def overwrites(self) -> Mapping[Union[Role, Member], PermissionOverwrite]:
"""Returns all of the channel's overwrites.
This is returned as a dictionary where the key contains the target which
@ -476,7 +458,7 @@ class GuildChannel:
Returns
--------
Dict[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
Mapping[Union[:class:`~discord.Role`, :class:`~discord.Member`], :class:`~discord.PermissionOverwrite`]
The channel's permission overwrites.
"""
ret = {}
@ -506,7 +488,7 @@ class GuildChannel:
If there is no category then this is ``None``.
"""
return self.guild.get_channel(self.category_id) # type: ignore
return self.guild.get_channel(self.category_id)
@property
def permissions_synced(self) -> bool:
@ -517,9 +499,6 @@ class GuildChannel:
.. versionadded:: 1.3
"""
if self.category_id is None:
return False
category = self.guild.get_channel(self.category_id)
return bool(category and category.overwrites == self.overwrites)
@ -538,7 +517,6 @@ class GuildChannel:
someone with that role would have, which is essentially:
- The default role permissions
- The permissions of the role used as a parameter
- The default role permission overwrites
- The permission overwrites of the role used as a parameter
@ -580,26 +558,24 @@ class GuildChannel:
# Handle the role case first
if isinstance(obj, Role):
base.value |= obj._permissions
if base.administrator:
return Permissions.all()
# Apply @everyone allow/deny first since it's special
try:
maybe_everyone = self._overwrites[0]
if maybe_everyone.id == self.guild.id:
base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny)
except IndexError:
pass
if obj.is_default():
overwrite = utils.get(self._overwrites, type=_Overwrites.ROLE, id=obj.id)
if overwrite is not None:
base.handle_overwrite(overwrite.allow, overwrite.deny)
return base
overwrite = utils.get(self._overwrites, type=_Overwrites.ROLE, id=obj.id)
if overwrite is not None:
base.handle_overwrite(overwrite.allow, overwrite.deny)
denies = 0
allows = 0
guild_id = self.guild.id
for overwrite in self._overwrites:
if not overwrite.is_role():
continue
if overwrite.id in (obj.id, guild_id):
denies |= overwrite.deny
allows |= overwrite.allow
base.handle_overwrite(allows, denies)
return base
roles = obj._roles
@ -703,7 +679,14 @@ class GuildChannel:
) -> None:
...
async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions):
async def set_permissions(
self,
target,
*,
overwrite=_undefined,
reason=None,
**permissions
):
r"""|coro|
Sets the channel specific permission overwrites for a target in the
@ -779,18 +762,18 @@ class GuildChannel:
elif isinstance(target, Role):
perm_type = _Overwrites.ROLE
else:
raise InvalidArgument("target parameter must be either Member or Role")
raise InvalidArgument('target parameter must be either Member or Role')
if overwrite is _undefined:
if len(permissions) == 0:
raise InvalidArgument("No overwrite provided.")
raise InvalidArgument('No overwrite provided.')
try:
overwrite = PermissionOverwrite(**permissions)
except (ValueError, TypeError):
raise InvalidArgument("Invalid permissions given to keyword arguments.")
raise InvalidArgument('Invalid permissions given to keyword arguments.')
else:
if len(permissions) > 0:
raise InvalidArgument("Cannot mix overwrite and keyword arguments.")
raise InvalidArgument('Cannot mix overwrite and keyword arguments.')
# TODO: wait for event
@ -800,7 +783,7 @@ class GuildChannel:
(allow, deny) = overwrite.pair()
await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason)
else:
raise InvalidArgument("Invalid overwrite type provided.")
raise InvalidArgument('Invalid overwrite type provided.')
async def _clone_impl(
self: GCH,
@ -809,16 +792,16 @@ class GuildChannel:
name: Optional[str] = None,
reason: Optional[str] = None,
) -> GCH:
base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites]
base_attrs["parent_id"] = self.category_id
base_attrs["name"] = name or self.name
base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites]
base_attrs['parent_id'] = self.category_id
base_attrs['name'] = name or self.name
guild_id = self.guild.id
cls = self.__class__
data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs)
obj = cls(state=self._state, guild=self.guild, data=data)
# temporarily add it to the cache
self.guild._channels[obj.id] = obj # type: ignore
self.guild._channels[obj.id] = obj
return obj
async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH:
@ -964,16 +947,15 @@ class GuildChannel:
if not kwargs:
return
beginning, end = kwargs.get("beginning"), kwargs.get("end")
before, after = kwargs.get("before"), kwargs.get("after")
offset = kwargs.get("offset", 0)
beginning, end = kwargs.get('beginning'), kwargs.get('end')
before, after = kwargs.get('before'), kwargs.get('after')
offset = kwargs.get('offset', 0)
if sum(bool(a) for a in (beginning, end, before, after)) > 1:
raise InvalidArgument("Only one of [before, after, end, beginning] can be used.")
raise InvalidArgument('Only one of [before, after, end, beginning] can be used.')
bucket = self._sorting_bucket
parent_id = kwargs.get("category", MISSING)
parent_id = kwargs.get('category', MISSING)
# fmt: off
channels: List[GuildChannel]
if parent_id not in (MISSING, None):
parent_id = parent_id.id
channels = [
@ -1011,14 +993,14 @@ class GuildChannel:
index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None)
if index is None:
raise InvalidArgument("Could not resolve appropriate move position")
raise InvalidArgument('Could not resolve appropriate move position')
channels.insert(max((index + offset), 0), self)
payload = []
lock_permissions = kwargs.get("sync_permissions", False)
reason = kwargs.get("reason")
lock_permissions = kwargs.get('sync_permissions', False)
reason = kwargs.get('reason')
for index, channel in enumerate(channels):
d = {"id": channel.id, "position": index}
d = {'id': channel.id, 'position': index}
if parent_id is not MISSING and channel.id == self.id:
d.update(parent_id=parent_id, lock_permissions=lock_permissions)
payload.append(d)
@ -1035,7 +1017,7 @@ class GuildChannel:
unique: bool = True,
target_type: Optional[InviteTarget] = None,
target_user: Optional[User] = None,
target_application_id: Optional[int] = None,
target_application_id: Optional[int] = None
) -> Invite:
"""|coro|
@ -1061,11 +1043,11 @@ class GuildChannel:
invite.
reason: Optional[:class:`str`]
The reason for creating this invite. Shows up on the audit log.
target_type: Optional[:class:`.InviteTarget`]
target_type: Optional[:class:`InviteTarget`]
The type of target for the voice channel invite, if any.
.. versionadded:: 2.0
target_user: Optional[:class:`User`]
The user whose stream to display for this invite, required if `target_type` is `TargetType.stream`. The user must be streaming in the channel.
@ -1099,7 +1081,7 @@ class GuildChannel:
unique=unique,
target_type=target_type.value if target_type else None,
target_user_id=target_user.id if target_user else None,
target_application_id=target_application_id,
target_application_id=target_application_id
)
return Invite.from_incomplete(data=data, state=self._state)
@ -1129,7 +1111,7 @@ class GuildChannel:
return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data]
class Messageable:
class Messageable(Protocol):
"""An ABC that details the common operations on a model that can send messages.
The following implement this ABC:
@ -1140,28 +1122,31 @@ class Messageable:
- :class:`~discord.User`
- :class:`~discord.Member`
- :class:`~discord.ext.commands.Context`
- :class:`~discord.Thread`
Note
----
This ABC is not decorated with :func:`typing.runtime_checkable`, so will fail :func:`isinstance`/:func:`issubclass`
checks.
"""
__slots__ = ()
_state: ConnectionState
async def _get_channel(self) -> MessageableChannel:
async def _get_channel(self):
raise NotImplementedError
@overload
async def send(
self,
content: Optional[str] = ...,
content: Optional[str] =...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
delete_after: int = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
@ -1175,69 +1160,19 @@ class Messageable:
tts: bool = ...,
embed: Embed = ...,
files: List[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
delete_after: int = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: List[Embed] = ...,
file: File = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
...
@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embeds: List[Embed] = ...,
files: List[File] = ...,
stickers: Sequence[Union[GuildSticker, StickerItem]] = ...,
delete_after: float = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference, PartialMessage] = ...,
mention_author: bool = ...,
view: View = ...,
) -> Message:
...
async def send(
self,
content=None,
*,
tts=None,
embed=None,
embeds=None,
file=None,
files=None,
stickers=None,
delete_after=None,
nonce=None,
allowed_mentions=None,
reference=None,
mention_author=None,
view=None,
):
async def send(self, content=None, *, tts=False, embed=None, file=None,
files=None, delete_after=None, nonce=None,
allowed_mentions=None, reference=None,
mention_author=None, view=None):
"""|coro|
Sends a message to the destination with the content given.
@ -1251,14 +1186,12 @@ class Messageable:
parameter should be used with a :class:`list` of :class:`~discord.File` objects.
**Specifying both parameters will lead to an exception**.
To upload a single embed, the ``embed`` parameter should be used with a
single :class:`~discord.Embed` object. To upload multiple embeds, the ``embeds``
parameter should be used with a :class:`list` of :class:`~discord.Embed` objects.
**Specifying both parameters will lead to an exception**.
If the ``embed`` parameter is provided, it must be of type :class:`~discord.Embed` and
it must be a rich embed type.
Parameters
------------
content: Optional[:class:`str`]
content: :class:`str`
The content of the message to send.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
@ -1285,7 +1218,7 @@ class Messageable:
.. versionadded:: 1.4
reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`, :class:`~discord.PartialMessage`]
reference: Union[:class:`~discord.Message`, :class:`~discord.MessageReference`]
A reference to the :class:`~discord.Message` to which you are replying, this can be created using
:meth:`~discord.Message.to_reference` or passed directly as a :class:`~discord.Message`. You can control
whether this mentions the author of the referenced message using the :attr:`~discord.AllowedMentions.replied_user`
@ -1299,12 +1232,6 @@ class Messageable:
.. versionadded:: 1.6
view: :class:`discord.ui.View`
A Discord UI View to add to the message.
embeds: List[:class:`~discord.Embed`]
A list of embeds to upload. Must be a maximum of 10.
.. versionadded:: 2.0
stickers: Sequence[Union[:class:`~discord.GuildSticker`, :class:`~discord.StickerItem`]]
A list of stickers to upload. Must be a maximum of 3.
.. versionadded:: 2.0
@ -1317,9 +1244,8 @@ class Messageable:
~discord.InvalidArgument
The ``files`` list is not of the appropriate size,
you specified both ``file`` and ``files``,
or you specified both ``embed`` and ``embeds``,
or the ``reference`` object is not a :class:`~discord.Message`,
:class:`~discord.MessageReference` or :class:`~discord.PartialMessage`.
or the ``reference`` object is not a :class:`~discord.Message`
or :class:`~discord.MessageReference`.
Returns
---------
@ -1330,21 +1256,9 @@ class Messageable:
channel = await self._get_channel()
state = self._state
content = str(content) if content is not None else None
if embed is not None and embeds is not None:
raise InvalidArgument("cannot pass both embed and embeds parameter to send()")
if embed is not None:
embed = embed.to_dict()
elif embeds is not None:
if len(embeds) > 10:
raise InvalidArgument("embeds parameter must be a list of up to 10 elements")
embeds = [embed.to_dict() for embed in embeds]
if stickers is not None:
stickers = [sticker.id for sticker in stickers]
if allowed_mentions is not None:
if state.allowed_mentions is not None:
allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict()
@ -1355,84 +1269,53 @@ class Messageable:
if mention_author is not None:
allowed_mentions = allowed_mentions or AllowedMentions().to_dict()
allowed_mentions["replied_user"] = bool(mention_author)
allowed_mentions['replied_user'] = bool(mention_author)
if reference is not None:
try:
reference = reference.to_message_reference_dict()
except AttributeError:
raise InvalidArgument(
"reference parameter must be Message, MessageReference, or PartialMessage"
) from None
raise InvalidArgument('reference parameter must be Message or MessageReference') from None
if view:
if not hasattr(view, "__discord_ui_view__"):
raise InvalidArgument(f"view parameter must be View not {view.__class__!r}")
if not hasattr(view, '__discord_ui_view__'):
raise InvalidArgument(f'view parameter must be View not {view.__class__!r}')
components = view.to_components()
else:
components = None
if file is not None and files is not None:
raise InvalidArgument("cannot pass both file and files parameter to send()")
raise InvalidArgument('cannot pass both file and files parameter to send()')
if file is not None:
if not isinstance(file, File):
raise InvalidArgument("file parameter must be File")
raise InvalidArgument('file parameter must be File')
try:
data = await state.http.send_files(
channel.id,
files=[file],
allowed_mentions=allowed_mentions,
content=content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
message_reference=reference,
stickers=stickers,
components=components,
)
data = await state.http.send_files(channel.id, files=[file], allowed_mentions=allowed_mentions,
content=content, tts=tts, embed=embed, nonce=nonce,
message_reference=reference, components=components)
finally:
file.close()
elif files is not None:
if len(files) > 10:
raise InvalidArgument("files parameter must be a list of up to 10 elements")
raise InvalidArgument('files parameter must be a list of up to 10 elements')
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument("files parameter must be a list of File")
raise InvalidArgument('files parameter must be a list of File')
try:
data = await state.http.send_files(
channel.id,
files=files,
content=content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
stickers=stickers,
components=components,
)
data = await state.http.send_files(channel.id, files=files, content=content, tts=tts,
embed=embed, nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference, components=components)
finally:
for f in files:
f.close()
else:
data = await state.http.send_message(
channel.id,
content,
tts=tts,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
message_reference=reference,
stickers=stickers,
components=components,
)
data = await state.http.send_message(channel.id, content, tts=tts, embed=embed,
nonce=nonce, allowed_mentions=allowed_mentions,
message_reference=reference, components=components)
ret = state.create_message(channel=channel, data=data)
if view:
@ -1442,7 +1325,7 @@ class Messageable:
await ret.delete(delay=delete_after)
return ret
async def trigger_typing(self) -> None:
async def trigger_typing(self):
"""|coro|
Triggers a *typing* indicator to the destination.
@ -1453,7 +1336,7 @@ class Messageable:
channel = await self._get_channel()
await self._state.http.send_typing(channel.id)
def typing(self) -> Typing:
def typing(self):
"""Returns a context manager that allows you to type for an indefinite period of time.
This is useful for denoting long computations in your bot.
@ -1464,9 +1347,8 @@ class Messageable:
This means that both ``with`` and ``async with`` work with this.
Example Usage: ::
async with channel.typing():
# simulate something heavy
async with channel.typing():
# simulate something heavy
await asyncio.sleep(10)
await channel.send('done!')
@ -1474,7 +1356,7 @@ class Messageable:
"""
return Typing(self)
async def fetch_message(self, id: int, /) -> Message:
async def fetch_message(self, id):
"""|coro|
Retrieves a single :class:`~discord.Message` from the destination.
@ -1503,7 +1385,7 @@ class Messageable:
data = await self._state.http.get_message(channel.id, id)
return self._state.create_message(channel=channel, data=data)
async def pins(self) -> List[Message]:
async def pins(self):
"""|coro|
Retrieves all messages that are currently pinned in the channel.
@ -1530,15 +1412,7 @@ class Messageable:
data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data]
def history(
self,
*,
limit: Optional[int] = 100,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = None,
) -> HistoryIterator:
def history(self, *, limit=100, before=None, after=None, around=None, oldest_first=None):
"""Returns an :class:`~discord.AsyncIterator` that enables receiving the destination's message history.
You must have :attr:`~discord.Permissions.read_message_history` permissions to use this.
@ -1615,28 +1489,19 @@ class Connectable(Protocol):
"""
__slots__ = ()
_state: ConnectionState
def _get_voice_client_key(self) -> Tuple[int, str]:
def _get_voice_client_key(self):
raise NotImplementedError
def _get_voice_state_pair(self) -> Tuple[int, int]:
def _get_voice_state_pair(self):
raise NotImplementedError
async def connect(
self,
*,
timeout: float = 60.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = VoiceClient,
) -> T:
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T:
"""|coro|
Connects to voice and creates a :class:`VoiceClient` to establish
your connection to the voice server.
This requires :attr:`Intents.voice_states`.
Parameters
-----------
timeout: :class:`float`
@ -1668,13 +1533,13 @@ class Connectable(Protocol):
state = self._state
if state._get_voice_client(key_id):
raise ClientException("Already connected to a voice channel.")
raise ClientException('Already connected to a voice channel.')
client = state._get_client()
voice = cls(client, self)
if not isinstance(voice, VoiceProtocol):
raise TypeError("Type must meet VoiceProtocol abstract base class.")
raise TypeError('Type must meet VoiceProtocol abstract base class.')
state._add_voice_client(key_id, voice)

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, overload
from typing import List, TYPE_CHECKING
from .asset import Asset
from .enums import ActivityType, try_enum
@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji
from .utils import _get_as_snowflake
__all__ = (
"BaseActivity",
"Activity",
"Streaming",
"Game",
"Spotify",
"CustomActivity",
'BaseActivity',
'Activity',
'Streaming',
'Game',
'Spotify',
'CustomActivity',
)
"""If curious, this is the current schema for an activity.
@ -92,7 +92,6 @@ t.ActivityFlags = {
if TYPE_CHECKING:
from .types.activity import (
Activity as ActivityPayload,
ActivityTimestamps,
ActivityParty,
ActivityAssets,
@ -119,22 +118,19 @@ class BaseActivity:
.. versionadded:: 1.3
"""
__slots__ = ("_created_at",)
__slots__ = ('_created_at',)
def __init__(self, **kwargs):
self._created_at: Optional[float] = kwargs.pop("created_at", None)
self._created_at = kwargs.pop('created_at', None)
@property
def created_at(self) -> Optional[datetime.datetime]:
def created_at(self):
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC.
.. versionadded:: 1.3
"""
if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
def to_dict(self) -> ActivityPayload:
raise NotImplementedError
return datetime.datetime.utcfromtimestamp(self._created_at / 1000)
class Activity(BaseActivity):
@ -151,17 +147,17 @@ class Activity(BaseActivity):
Attributes
------------
application_id: Optional[:class:`int`]
application_id: :class:`int`
The application ID of the game.
name: Optional[:class:`str`]
name: :class:`str`
The name of the activity.
url: Optional[:class:`str`]
url: :class:`str`
A stream URL that the activity could be doing.
type: :class:`ActivityType`
The type of activity currently being done.
state: Optional[:class:`str`]
state: :class:`str`
The user's current state. For example, "In Game".
details: Optional[:class:`str`]
details: :class:`str`
The detail of the user's current activity.
timestamps: :class:`dict`
A dictionary of timestamps. It contains the following optional keys:
@ -199,61 +195,62 @@ class Activity(BaseActivity):
"""
__slots__ = (
"state",
"details",
"_created_at",
"timestamps",
"assets",
"party",
"flags",
"sync_id",
"session_id",
"type",
"name",
"url",
"application_id",
"emoji",
"buttons",
'state',
'details',
'_created_at',
'timestamps',
'assets',
'party',
'flags',
'sync_id',
'session_id',
'type',
'name',
'url',
'application_id',
'emoji',
'buttons',
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop("state", None)
self.details: Optional[str] = kwargs.pop("details", None)
self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {})
self.assets: ActivityAssets = kwargs.pop("assets", {})
self.party: ActivityParty = kwargs.pop("party", {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id")
self.name: Optional[str] = kwargs.pop("name", None)
self.url: Optional[str] = kwargs.pop("url", None)
self.flags: int = kwargs.pop("flags", 0)
self.sync_id: Optional[str] = kwargs.pop("sync_id", None)
self.session_id: Optional[str] = kwargs.pop("session_id", None)
self.buttons: List[ActivityButton] = kwargs.pop("buttons", [])
self.state = kwargs.pop('state', None)
self.details = kwargs.pop('details', None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {})
self.assets: ActivityAssets = kwargs.pop('assets', {})
self.party: ActivityParty = kwargs.pop('party', {})
self.application_id = _get_as_snowflake(kwargs, 'application_id')
self.name = kwargs.pop('name', None)
self.url = kwargs.pop('url', None)
self.flags = kwargs.pop('flags', 0)
self.sync_id = kwargs.pop('sync_id', None)
self.session_id = kwargs.pop('session_id', None)
self.buttons: List[ActivityButton] = kwargs.pop('buttons', [])
activity_type = kwargs.pop("type", -1)
self.type: ActivityType = (
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
)
activity_type = kwargs.pop('type', -1)
self.type = activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
emoji = kwargs.pop("emoji", None)
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None
emoji = kwargs.pop('emoji', None)
if emoji is not None:
self.emoji = PartialEmoji.from_dict(emoji)
else:
self.emoji = None
def __repr__(self) -> str:
def __repr__(self):
attrs = (
("type", self.type),
("name", self.name),
("url", self.url),
("details", self.details),
("application_id", self.application_id),
("session_id", self.session_id),
("emoji", self.emoji),
('type', self.type),
('name', self.name),
('url', self.url),
('details', self.details),
('application_id', self.application_id),
('session_id', self.session_id),
('emoji', self.emoji),
)
inner = " ".join("%s=%r" % t for t in attrs)
return f"<Activity {inner}>"
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Activity {inner}>'
def to_dict(self) -> Dict[str, Any]:
ret: Dict[str, Any] = {}
def to_dict(self):
ret = {}
for attr in self.__slots__:
value = getattr(self, attr, None)
if value is None:
@ -263,66 +260,66 @@ class Activity(BaseActivity):
continue
ret[attr] = value
ret["type"] = int(self.type)
ret['type'] = int(self.type)
if self.emoji:
ret["emoji"] = self.emoji.to_dict()
ret['emoji'] = self.emoji.to_dict()
return ret
@property
def start(self) -> Optional[datetime.datetime]:
def start(self):
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps["start"] / 1000
timestamp = self.timestamps['start'] / 1000
except KeyError:
return None
else:
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
@property
def end(self) -> Optional[datetime.datetime]:
def end(self):
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try:
timestamp = self.timestamps["end"] / 1000
timestamp = self.timestamps['end'] / 1000
except KeyError:
return None
else:
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
@property
def large_image_url(self) -> Optional[str]:
def large_image_url(self):
"""Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable."""
if self.application_id is None:
return None
try:
large_image = self.assets["large_image"]
large_image = self.assets['large_image']
except KeyError:
return None
else:
return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png"
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png'
@property
def small_image_url(self) -> Optional[str]:
def small_image_url(self):
"""Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable."""
if self.application_id is None:
return None
try:
small_image = self.assets["small_image"]
small_image = self.assets['small_image']
except KeyError:
return None
else:
return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png"
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png'
@property
def large_image_text(self) -> Optional[str]:
def large_image_text(self):
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
return self.assets.get("large_text", None)
return self.assets.get('large_text', None)
@property
def small_image_text(self) -> Optional[str]:
def small_image_text(self):
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
return self.assets.get("small_text", None)
return self.assets.get('small_text', None)
class Game(BaseActivity):
@ -359,23 +356,23 @@ class Game(BaseActivity):
The game's name.
"""
__slots__ = ("name", "_end", "_start")
__slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra):
def __init__(self, name, **extra):
super().__init__(**extra)
self.name: str = name
self.name = name
try:
timestamps: ActivityTimestamps = extra["timestamps"]
timestamps: ActivityTimestamps = extra['timestamps']
except KeyError:
self._start = 0
self._end = 0
else:
self._start = timestamps.get("start", 0)
self._end = timestamps.get("end", 0)
self._start = timestamps.get('start', 0)
self._end = timestamps.get('end', 0)
@property
def type(self) -> ActivityType:
def type(self):
""":class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.playing`.
@ -383,32 +380,32 @@ class Game(BaseActivity):
return ActivityType.playing
@property
def start(self) -> Optional[datetime.datetime]:
def start(self):
"""Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable."""
if self._start:
return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(self._start / 1000).replace(tzinfo=datetime.timezone.utc)
return None
@property
def end(self) -> Optional[datetime.datetime]:
def end(self):
"""Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable."""
if self._end:
return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(self._end / 1000).replace(tzinfo=datetime.timezone.utc)
return None
def __str__(self) -> str:
def __str__(self):
return str(self.name)
def __repr__(self) -> str:
return f"<Game name={self.name!r}>"
def __repr__(self):
return f'<Game name={self.name!r}>'
def to_dict(self) -> Dict[str, Any]:
timestamps: Dict[str, Any] = {}
def to_dict(self):
timestamps = {}
if self._start:
timestamps["start"] = self._start
timestamps['start'] = self._start
if self._end:
timestamps["end"] = self._end
timestamps['end'] = self._end
# fmt: off
return {
@ -418,13 +415,13 @@ class Game(BaseActivity):
}
# fmt: on
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return hash(self.name)
@ -453,7 +450,7 @@ class Streaming(BaseActivity):
Attributes
-----------
platform: Optional[:class:`str`]
platform: :class:`str`
Where the user is streaming from (ie. YouTube, Twitch).
.. versionadded:: 1.3
@ -473,30 +470,30 @@ class Streaming(BaseActivity):
A dictionary comprising of similar keys than those in :attr:`Activity.assets`.
"""
__slots__ = ("platform", "name", "game", "url", "details", "assets")
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any):
def __init__(self, *, name, url, **extra):
super().__init__(**extra)
self.platform: Optional[str] = name
self.name: Optional[str] = extra.pop("details", name)
self.game: Optional[str] = extra.pop("state", None)
self.url: str = url
self.details: Optional[str] = extra.pop("details", self.name) # compatibility
self.assets: ActivityAssets = extra.pop("assets", {})
self.platform = name
self.name = extra.pop('details', name)
self.game = extra.pop('state', None)
self.url = url
self.details = extra.pop('details', self.name) # compatibility
self.assets: ActivityAssets = extra.pop('assets', {})
@property
def type(self) -> ActivityType:
def type(self):
""":class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.streaming`.
"""
return ActivityType.streaming
def __str__(self) -> str:
def __str__(self):
return str(self.name)
def __repr__(self) -> str:
return f"<Streaming name={self.name!r}>"
def __repr__(self):
return f'<Streaming name={self.name!r}>'
@property
def twitch_name(self):
@ -507,15 +504,15 @@ class Streaming(BaseActivity):
"""
try:
name = self.assets["large_image"]
name = self.assets['large_image']
except KeyError:
return None
else:
return name[7:] if name[:7] == "twitch:" else None
return name[7:] if name[:7] == 'twitch:' else None
def to_dict(self) -> Dict[str, Any]:
def to_dict(self):
# fmt: off
ret: Dict[str, Any] = {
ret = {
'type': ActivityType.streaming.value,
'name': str(self.name),
'url': str(self.url),
@ -523,16 +520,16 @@ class Streaming(BaseActivity):
}
# fmt: on
if self.details:
ret["details"] = self.details
ret['details'] = self.details
return ret
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return hash(self.name)
@ -559,20 +556,20 @@ class Spotify:
Returns the string 'Spotify'.
"""
__slots__ = ("_state", "_details", "_timestamps", "_assets", "_party", "_sync_id", "_session_id", "_created_at")
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at')
def __init__(self, **data):
self._state: str = data.pop("state", "")
self._details: str = data.pop("details", "")
self._timestamps: Dict[str, int] = data.pop("timestamps", {})
self._assets: ActivityAssets = data.pop("assets", {})
self._party: ActivityParty = data.pop("party", {})
self._sync_id: str = data.pop("sync_id")
self._session_id: str = data.pop("session_id")
self._created_at: Optional[float] = data.pop("created_at", None)
self._state = data.pop('state', None)
self._details = data.pop('details', None)
self._timestamps = data.pop('timestamps', {})
self._assets = data.pop('assets', {})
self._party = data.pop('party', {})
self._sync_id = data.pop('sync_id')
self._session_id = data.pop('session_id')
self._created_at = data.pop('created_at', None)
@property
def type(self) -> ActivityType:
def type(self):
""":class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.listening`.
@ -580,47 +577,47 @@ class Spotify:
return ActivityType.listening
@property
def created_at(self) -> Optional[datetime.datetime]:
def created_at(self):
"""Optional[:class:`datetime.datetime`]: When the user started listening in UTC.
.. versionadded:: 1.3
"""
if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(self._created_at / 1000)
@property
def colour(self) -> Colour:
def colour(self):
""":class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`.
There is an alias for this named :attr:`color`"""
return Colour(0x1DB954)
@property
def color(self) -> Colour:
def color(self):
""":class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`.
There is an alias for this named :attr:`colour`"""
return self.colour
def to_dict(self) -> Dict[str, Any]:
def to_dict(self):
return {
"flags": 48, # SYNC | PLAY
"name": "Spotify",
"assets": self._assets,
"party": self._party,
"sync_id": self._sync_id,
"session_id": self._session_id,
"timestamps": self._timestamps,
"details": self._details,
"state": self._state,
'flags': 48, # SYNC | PLAY
'name': 'Spotify',
'assets': self._assets,
'party': self._party,
'sync_id': self._sync_id,
'session_id': self._session_id,
'timestamps': self._timestamps,
'details': self._details,
'state': self._state,
}
@property
def name(self) -> str:
def name(self):
""":class:`str`: The activity's name. This will always return "Spotify"."""
return "Spotify"
return 'Spotify'
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return (
isinstance(other, Spotify)
and other._session_id == self._session_id
@ -628,30 +625,30 @@ class Spotify:
and other.start == self.start
)
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return hash(self._session_id)
def __str__(self) -> str:
return "Spotify"
def __str__(self):
return 'Spotify'
def __repr__(self) -> str:
return f"<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>"
def __repr__(self):
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>'
@property
def title(self) -> str:
def title(self):
""":class:`str`: The title of the song being played."""
return self._details
@property
def artists(self) -> List[str]:
def artists(self):
"""List[:class:`str`]: The artists of the song being played."""
return self._state.split("; ")
return self._state.split('; ')
@property
def artist(self) -> str:
def artist(self):
""":class:`str`: The artist of the song being played.
This does not attempt to split the artist information into
@ -660,51 +657,43 @@ class Spotify:
return self._state
@property
def album(self) -> str:
def album(self):
""":class:`str`: The album that the song being played belongs to."""
return self._assets.get("large_text", "")
return self._assets.get('large_text', '')
@property
def album_cover_url(self) -> str:
def album_cover_url(self):
""":class:`str`: The album cover image URL from Spotify's CDN."""
large_image = self._assets.get("large_image", "")
if large_image[:8] != "spotify:":
return ""
large_image = self._assets.get('large_image', '')
if large_image[:8] != 'spotify:':
return ''
album_image_id = large_image[8:]
return "https://i.scdn.co/image/" + album_image_id
return 'https://i.scdn.co/image/' + album_image_id
@property
def track_id(self) -> str:
def track_id(self):
""":class:`str`: The track ID used by Spotify to identify this song."""
return self._sync_id
@property
def track_url(self) -> str:
""":class:`str`: The track URL to listen on Spotify.
.. versionadded:: 2.0
"""
return f"https://open.spotify.com/track/{self.track_id}"
@property
def start(self) -> datetime.datetime:
def start(self):
""":class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(self._timestamps['start'] / 1000)
@property
def end(self) -> datetime.datetime:
def end(self):
""":class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(self._timestamps['end'] / 1000)
@property
def duration(self) -> datetime.timedelta:
def duration(self):
""":class:`datetime.timedelta`: The duration of the song being played."""
return self.end - self.start
@property
def party_id(self) -> str:
def party_id(self):
""":class:`str`: The party ID of the listening party."""
return self._party.get("id", "")
return self._party.get('id', '')
class CustomActivity(BaseActivity):
@ -738,16 +727,15 @@ class CustomActivity(BaseActivity):
The emoji to pass to the activity, if any.
"""
__slots__ = ("name", "emoji", "state")
__slots__ = ('name', 'emoji', 'state')
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any):
def __init__(self, name, *, emoji=None, **extra):
super().__init__(**extra)
self.name: Optional[str] = name
self.state: Optional[str] = extra.pop("state", None)
if self.name == "Custom Status":
self.name = name
self.state = extra.pop('state', None)
if self.name == 'Custom Status':
self.name = self.state
self.emoji: Optional[PartialEmoji]
if emoji is None:
self.emoji = emoji
elif isinstance(emoji, dict):
@ -757,89 +745,74 @@ class CustomActivity(BaseActivity):
elif isinstance(emoji, PartialEmoji):
self.emoji = emoji
else:
raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.")
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.')
@property
def type(self) -> ActivityType:
def type(self):
""":class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.custom`.
"""
return ActivityType.custom
def to_dict(self) -> Dict[str, Any]:
def to_dict(self):
if self.name == self.state:
o = {
"type": ActivityType.custom.value,
"state": self.name,
"name": "Custom Status",
'type': ActivityType.custom.value,
'state': self.name,
'name': 'Custom Status',
}
else:
o = {
"type": ActivityType.custom.value,
"name": self.name,
'type': ActivityType.custom.value,
'name': self.name,
}
if self.emoji:
o["emoji"] = self.emoji.to_dict()
o['emoji'] = self.emoji.to_dict()
return o
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return hash((self.name, str(self.emoji)))
def __str__(self) -> str:
def __str__(self):
if self.emoji:
if self.name:
return f"{self.emoji} {self.name}"
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name)
def __repr__(self) -> str:
return f"<CustomActivity name={self.name!r} emoji={self.emoji!r}>"
def __repr__(self):
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify]
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
def create_activity(data):
if not data:
return None
game_type = try_enum(ActivityType, data.get("type", -1))
game_type = try_enum(ActivityType, data.get('type', -1))
if game_type is ActivityType.playing:
if "application_id" in data or "session_id" in data:
if 'application_id' in data or 'session_id' in data:
return Activity(**data)
return Game(**data)
elif game_type is ActivityType.custom:
try:
name = data.pop("name")
name = data.pop('name')
except KeyError:
return Activity(**data)
else:
# we removed the name key from data already
return CustomActivity(name=name, **data) # type: ignore
return CustomActivity(name=name, **data)
elif game_type is ActivityType.streaming:
if "url" in data:
# the url won't be None here
return Streaming(**data) # type: ignore
if 'url' in data:
return Streaming(**data)
return Activity(**data)
elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data:
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data)
return Activity(**data)

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
from .state import ConnectionState
__all__ = (
"AppInfo",
"PartialAppInfo",
'AppInfo',
'PartialAppInfo',
)
@ -115,58 +115,58 @@ class AppInfo:
"""
__slots__ = (
"_state",
"description",
"id",
"name",
"rpc_origins",
"bot_public",
"bot_require_code_grant",
"owner",
"_icon",
"summary",
"verify_key",
"team",
"guild_id",
"primary_sku_id",
"slug",
"_cover_image",
"terms_of_service_url",
"privacy_policy_url",
'_state',
'description',
'id',
'name',
'rpc_origins',
'bot_public',
'bot_require_code_grant',
'owner',
'_icon',
'summary',
'verify_key',
'team',
'guild_id',
'primary_sku_id',
'slug',
'_cover_image',
'terms_of_service_url',
'privacy_policy_url',
)
def __init__(self, state: ConnectionState, data: AppInfoPayload):
from .team import Team
self._state: ConnectionState = state
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self._icon: Optional[str] = data["icon"]
self.rpc_origins: List[str] = data["rpc_origins"]
self.bot_public: bool = data["bot_public"]
self.bot_require_code_grant: bool = data["bot_require_code_grant"]
self.owner: User = state.create_user(data["owner"])
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self._icon: Optional[str] = data['icon']
self.rpc_origins: List[str] = data['rpc_origins']
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.store_user(data['owner'])
team: Optional[TeamPayload] = data.get("team")
team: Optional[TeamPayload] = data.get('team')
self.team: Optional[Team] = Team(state, team) if team else None
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id')
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, "primary_sku_id")
self.slug: Optional[str] = data.get("slug")
self._cover_image: Optional[str] = data.get("cover_image")
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id')
self.slug: Optional[str] = data.get('slug')
self._cover_image: Optional[str] = data.get('cover_image')
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"description={self.description!r} public={self.bot_public} "
f"owner={self.owner!r}>"
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'description={self.description!r} public={self.bot_public} '
f'owner={self.owner!r}>'
)
@property
@ -174,7 +174,7 @@ class AppInfo:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path="app")
return Asset._from_icon(self._state, self.id, self._icon, path='app')
@property
def cover_image(self) -> Optional[Asset]:
@ -195,9 +195,8 @@ class AppInfo:
"""
return self._state._get_guild(self.guild_id)
class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
"""Represents a partial AppInfo given by :func:`~GuildChannel.create_invite`
.. versionadded:: 2.0
@ -223,37 +222,26 @@ class PartialAppInfo:
The application's privacy policy URL, if set.
"""
__slots__ = (
"_state",
"id",
"name",
"description",
"rpc_origins",
"summary",
"verify_key",
"terms_of_service_url",
"privacy_policy_url",
"_icon",
)
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon')
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
self._state: ConnectionState = state
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data.get("icon")
self.description: str = data["description"]
self.rpc_origins: Optional[List[str]] = data.get("rpc_origins")
self.summary: str = data["summary"]
self.verify_key: str = data["verify_key"]
self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url")
self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url")
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data.get('icon')
self.description: str = data['description']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>"
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path="app")
return Asset._from_icon(self._state, self.id, self._icon, path='app')

View File

@ -33,11 +33,13 @@ from . import utils
import yarl
__all__ = ("Asset",)
__all__ = (
'Asset',
)
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
@ -45,7 +47,6 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
MISSING = utils.MISSING
class AssetMixin:
url: str
_state: Optional[Any]
@ -70,7 +71,7 @@ class AssetMixin:
The content of the asset.
"""
if self._state is None:
raise DiscordException("Invalid state (no ConnectionState provided)")
raise DiscordException('Invalid state (no ConnectionState provided)')
return await self._state.http.get_from_cdn(self.url)
@ -111,7 +112,7 @@ class AssetMixin:
fp.seek(0)
return written
else:
with open(fp, "wb") as f:
with open(fp, 'wb') as f:
return f.write(data)
@ -142,13 +143,13 @@ class Asset(AssetMixin):
"""
__slots__: Tuple[str, ...] = (
"_state",
"_url",
"_animated",
"_key",
'_state',
'_url',
'_animated',
'_key',
)
BASE = "https://cdn.discordapp.com"
BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False):
self._state = state
@ -160,29 +161,18 @@ class Asset(AssetMixin):
def _from_default_avatar(cls, state, index: int) -> Asset:
return cls(
state,
url=f"{cls.BASE}/embed/avatars/{index}.png",
url=f'{cls.BASE}/embed/avatars/{index}.png',
key=str(index),
animated=False,
)
@classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
animated = avatar.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024",
key=avatar,
animated=animated,
)
@classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset:
animated = avatar.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024",
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024',
key=avatar,
animated=animated,
)
@ -191,7 +181,7 @@ class Asset(AssetMixin):
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
return cls(
state,
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024',
key=icon_hash,
animated=False,
)
@ -200,7 +190,7 @@ class Asset(AssetMixin):
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024',
key=cover_image_hash,
animated=False,
)
@ -209,42 +199,31 @@ class Asset(AssetMixin):
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
return cls(
state,
url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024",
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024',
key=image,
animated=False,
)
@classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
animated = icon_hash.startswith("a_")
format = "gif" if animated else "png"
animated = icon_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024",
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024',
key=icon_hash,
animated=animated,
)
@classmethod
def _from_sticker_banner(cls, state, banner: int) -> Asset:
def _from_sticker(cls, state, sticker_id: int, sticker_hash: str) -> Asset:
return cls(
state,
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
key=str(banner),
url=f'{cls.BASE}/stickers/{sticker_id}/{sticker_hash}.png?size=1024',
key=sticker_hash,
animated=False,
)
@classmethod
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
animated = banner_hash.startswith("a_")
format = "gif" if animated else "png"
return cls(
state,
url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512",
key=banner_hash,
animated=animated,
)
def __str__(self) -> str:
return self._url
@ -252,8 +231,8 @@ class Asset(AssetMixin):
return len(self._url)
def __repr__(self):
shorten = self._url.replace(self.BASE, "")
return f"<Asset url={shorten!r}>"
shorten = self._url.replace(self.BASE, '')
return f'<Asset url={shorten!r}>'
def __eq__(self, other):
return isinstance(other, Asset) and self._url == other._url
@ -277,7 +256,6 @@ class Asset(AssetMixin):
def replace(
self,
*,
size: int = MISSING,
format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING,
@ -311,21 +289,20 @@ class Asset(AssetMixin):
if format is not MISSING:
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
url = url.with_path(f"{path}.{format}")
elif static_format is MISSING:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f"{path}.{format}")
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')
if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
url = url.with_path(f"{path}.{static_format}")
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{static_format}')
if size is not MISSING:
if not utils.valid_icon_size(size):
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
url = url.with_query(size=size)
else:
url = url.with_query(url.raw_query_string)
@ -333,7 +310,7 @@ class Asset(AssetMixin):
url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset:
def with_size(self, size: int) -> Asset:
"""Returns a new asset with the specified size.
Parameters
@ -352,12 +329,12 @@ class Asset(AssetMixin):
The new updated asset.
"""
if not utils.valid_icon_size(size):
raise InvalidArgument("size must be a power of 2 between 16 and 4096")
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset:
def with_format(self, format: ValidAssetFormatTypes) -> Asset:
"""Returns a new asset with the specified format.
Parameters
@ -378,17 +355,17 @@ class Asset(AssetMixin):
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}")
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string))
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
def with_static_format(self, format: ValidStaticFormatTypes) -> Asset:
"""Returns a new asset with the specified static format.
This only changes the format if the underlying asset is

View File

@ -35,9 +35,9 @@ from .object import Object
from .permissions import PermissionOverwrite, Permissions
__all__ = (
"AuditLogDiff",
"AuditLogChanges",
"AuditLogEntry",
'AuditLogDiff',
'AuditLogChanges',
'AuditLogEntry',
)
@ -49,17 +49,12 @@ if TYPE_CHECKING:
from .guild import Guild
from .member import Member
from .role import Role
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
)
from .types.audit_log import AuditLogChange as AuditLogChangePayload
from .types.audit_log import AuditLogEntry as AuditLogEntryPayload
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
from .types.role import Role as RolePayload
from .types.snowflake import Snowflake
from .user import User
from .stage_instance import StageInstance
from .sticker import GuildSticker
from .threads import Thread
def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions:
@ -74,22 +69,22 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
return int(data)
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Object]:
if data is None:
return None
return entry.guild.get_channel(int(data)) or Object(id=data)
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
def _transform_owner_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
if data is None:
return None
return entry._get_member(int(data))
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]:
def _transform_inviter_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
if data is None:
return None
return entry._state._get_guild(data)
return entry._get_member(int(data))
def _transform_overwrites(
@ -97,16 +92,16 @@ def _transform_overwrites(
) -> List[Tuple[Object, PermissionOverwrite]]:
overwrites = []
for elem in data:
allow = Permissions(int(elem["allow"]))
deny = Permissions(int(elem["deny"]))
allow = Permissions(elem['allow'])
deny = Permissions(elem['deny'])
ow = PermissionOverwrite.from_pair(allow, deny)
ow_type = elem["type"]
ow_id = int(elem["id"])
ow_type = elem['type']
ow_id = int(elem['id'])
target = None
if ow_type == "0":
if ow_type == '0':
target = entry.guild.get_role(ow_id)
elif ow_type == "1":
elif ow_type == '1':
target = entry._get_member(ow_id)
if target is None:
@ -138,7 +133,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]
return _transform
T = TypeVar("T", bound=enums.Enum)
T = TypeVar('T', bound=enums.Enum)
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
@ -148,13 +143,6 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith("sticker_"):
return enums.try_enum(enums.StickerType, data)
else:
return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff:
def __len__(self) -> int:
return len(self.__dict__)
@ -163,8 +151,8 @@ class AuditLogDiff:
yield from self.__dict__.items()
def __repr__(self) -> str:
values = " ".join("%s=%r" % item for item in self.__dict__.items())
return f"<AuditLogDiff {values}>"
values = ' '.join('%s=%r' % item for item in self.__dict__.items())
return f'<AuditLogDiff {values}>'
if TYPE_CHECKING:
@ -188,8 +176,8 @@ class AuditLogChanges:
'permissions': (None, _transform_permissions),
'id': (None, _transform_snowflake),
'color': ('colour', _transform_color),
'owner_id': ('owner', _transform_member_id),
'inviter_id': ('inviter', _transform_member_id),
'owner_id': ('owner', _transform_owner_id),
'inviter_id': ('inviter', _transform_inviter_id),
'channel_id': ('channel', _transform_channel),
'afk_channel_id': ('afk_channel', _transform_channel),
'system_channel_id': ('system_channel', _transform_channel),
@ -203,15 +191,12 @@ class AuditLogChanges:
'icon_hash': ('icon', _transform_icon),
'avatar_hash': ('avatar', _transform_avatar),
'rate_limit_per_user': ('slowmode_delay', None),
'guild_id': ('guild', _transform_guild_id),
'tags': ('emoji', None),
'default_message_notifications': ('default_notifications', _enum_transformer(enums.NotificationLevel)),
'region': (None, _enum_transformer(enums.VoiceRegion)),
'rtc_region': (None, _enum_transformer(enums.VoiceRegion)),
'video_quality_mode': (None, _enum_transformer(enums.VideoQualityMode)),
'privacy_level': (None, _enum_transformer(enums.StagePrivacyLevel)),
'format_type': (None, _enum_transformer(enums.StickerFormatType)),
'type': (None, _transform_type),
'type': (None, _enum_transformer(enums.ChannelType)),
}
# fmt: on
@ -220,14 +205,14 @@ class AuditLogChanges:
self.after = AuditLogDiff()
for elem in data:
attr = elem["key"]
attr = elem['key']
# special cases for role add/remove
if attr == "$add":
self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore
if attr == '$add':
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore
continue
elif attr == "$remove":
self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore
elif attr == '$remove':
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore
continue
try:
@ -241,7 +226,7 @@ class AuditLogChanges:
transformer: Optional[Transformer]
try:
before = elem["old_value"]
before = elem['old_value']
except KeyError:
before = None
else:
@ -251,7 +236,7 @@ class AuditLogChanges:
setattr(self.before, attr, before)
try:
after = elem["new_value"]
after = elem['new_value']
except KeyError:
after = None
else:
@ -261,36 +246,34 @@ class AuditLogChanges:
setattr(self.after, attr, after)
# add an alias
if hasattr(self.after, "colour"):
if hasattr(self.after, 'colour'):
self.after.color = self.after.colour
self.before.color = self.before.colour
if hasattr(self.after, "expire_behavior"):
if hasattr(self.after, 'expire_behavior'):
self.after.expire_behaviour = self.after.expire_behavior
self.before.expire_behaviour = self.before.expire_behavior
def __repr__(self) -> str:
return f"<AuditLogChanges before={self.before!r} after={self.after!r}>"
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>'
def _handle_role(
self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]
) -> None:
if not hasattr(first, "roles"):
setattr(first, "roles", [])
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None:
if not hasattr(first, 'roles'):
setattr(first, 'roles', [])
data = []
g: Guild = entry.guild # type: ignore
for e in elem:
role_id = int(e["id"])
role_id = int(e['id'])
role = g.get_role(role_id)
if role is None:
role = Object(id=role_id)
role.name = e["name"] # type: ignore
role.name = e['name'] # type: ignore
data.append(role)
setattr(second, "roles", data)
setattr(second, 'roles', data)
class _AuditLogProxyMemberPrune:
@ -335,10 +318,6 @@ class AuditLogEntry(Hashable):
Returns the entry's hash.
.. describe:: int(x)
Returns the entry's ID.
.. versionchanged:: 1.7
Audit log entries are now comparable and hashable.
@ -370,56 +349,56 @@ class AuditLogEntry(Hashable):
self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None:
self.action = enums.try_enum(enums.AuditLogAction, data["action_type"])
self.id = int(data["id"])
self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id'])
# this key is technically not usually present
self.reason = data.get("reason")
self.extra = data.get("options")
self.reason = data.get('reason')
self.extra = data.get('options')
if isinstance(self.action, enums.AuditLogAction) and self.extra:
if self.action is enums.AuditLogAction.member_prune:
# member prune has two keys with useful information
self.extra: _AuditLogProxyMemberPrune = type(
"_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()}
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()}
)()
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
channel_id = int(self.extra["channel_id"])
channel_id = int(self.extra['channel_id'])
elems = {
"count": int(self.extra["count"]),
"channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
'count': int(self.extra['count']),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
}
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)()
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)()
elif self.action is enums.AuditLogAction.member_disconnect:
# The member disconnect action has a dict with some information
elems = {
"count": int(self.extra["count"]),
'count': int(self.extra['count']),
}
self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)()
elif self.action.name.endswith("pin"):
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)()
elif self.action.name.endswith('pin'):
# the pin actions have a dict with some information
channel_id = int(self.extra["channel_id"])
channel_id = int(self.extra['channel_id'])
elems = {
"channel": self.guild.get_channel(channel_id) or Object(id=channel_id),
"message_id": int(self.extra["message_id"]),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
'message_id': int(self.extra['message_id']),
}
self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)()
elif self.action.name.startswith("overwrite_"):
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith('overwrite_'):
# the overwrite_ actions have a dict with some information
instance_id = int(self.extra["id"])
the_type = self.extra.get("type")
if the_type == "1":
instance_id = int(self.extra['id'])
the_type = self.extra.get('type')
if the_type == '1':
self.extra = self._get_member(instance_id)
elif the_type == "0":
elif the_type == '0':
role = self.guild.get_role(instance_id)
if role is None:
role = Object(id=instance_id)
role.name = self.extra.get("role_name") # type: ignore
role.name = self.extra.get('role_name') # type: ignore
self.extra: Role = role
elif self.action.name.startswith("stage_instance"):
channel_id = int(self.extra["channel_id"])
elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)()
elif self.action.name.startswith('stage_instance'):
channel_id = int(self.extra['channel_id'])
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)()
# fmt: off
self.extra: Union[
@ -438,16 +417,16 @@ class AuditLogEntry(Hashable):
# where new_value and old_value are not guaranteed to be there depending
# on the action type, so let's just fetch it for now and only turn it
# into meaningful data when requested
self._changes = data.get("changes", [])
self._changes = data.get('changes', [])
self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore
self._target_id = utils._get_as_snowflake(data, "target_id")
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore
self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]:
return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str:
return f"<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>"
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>'
@utils.cached_property
def created_at(self) -> datetime.datetime:
@ -455,13 +434,9 @@ class AuditLogEntry(Hashable):
return utils.snowflake_time(self.id)
@utils.cached_property
def target(
self,
) -> Union[
Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None
]:
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, Object, None]:
try:
converter = getattr(self, "_convert_target_" + self.action.target_type)
converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError:
return Object(id=self._target_id)
else:
@ -507,11 +482,11 @@ class AuditLogEntry(Hashable):
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
fake_payload = {
"max_age": changeset.max_age,
"max_uses": changeset.max_uses,
"code": changeset.code,
"temporary": changeset.temporary,
"uses": changeset.uses,
'max_age': changeset.max_age,
'max_uses': changeset.max_uses,
'code': changeset.code,
'temporary': changeset.temporary,
'uses': changeset.uses,
}
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore
@ -526,12 +501,3 @@ class AuditLogEntry(Hashable):
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]:
return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
return self._state.get_sticker(target_id) or Object(id=target_id)
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id)

View File

@ -22,19 +22,14 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import time
import random
from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar("T", bool, Literal[True], Literal[False])
__all__ = (
'ExponentialBackoff',
)
__all__ = ("ExponentialBackoff",)
class ExponentialBackoff(Generic[T]):
class ExponentialBackoff:
"""An implementation of the exponential backoff algorithm
Provides a convenient interface to implement an exponential backoff
@ -56,33 +51,21 @@ class ExponentialBackoff(Generic[T]):
number in between may be returned.
"""
def __init__(self, base: int = 1, *, integral: T = False):
self._base: int = base
def __init__(self, base=1, *, integral=False):
self._base = base
self._exp: int = 0
self._max: int = 10
self._reset_time: int = base * 2 ** 11
self._last_invocation: float = time.monotonic()
self._exp = 0
self._max = 10
self._reset_time = base * 2 ** 11
self._last_invocation = time.monotonic()
# Use our own random instance to avoid messing with global one
rand = random.Random()
rand.seed()
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore
self._randfunc = rand.randrange if integral else rand.uniform
@overload
def delay(self: ExponentialBackoff[Literal[False]]) -> float:
...
@overload
def delay(self: ExponentialBackoff[Literal[True]]) -> int:
...
@overload
def delay(self: ExponentialBackoff[bool]) -> Union[int, float]:
...
def delay(self) -> Union[int, float]:
def delay(self):
"""Compute the next delay
Returns the next delay to wait according to the exponential

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -35,11 +35,11 @@ from typing import (
)
__all__ = (
"Colour",
"Color",
'Colour',
'Color',
)
CT = TypeVar("CT", bound="Colour")
CT = TypeVar('CT', bound='Colour')
class Colour:
@ -76,16 +76,16 @@ class Colour:
The raw integer colour value.
"""
__slots__ = ("value",)
__slots__ = ('value',)
def __init__(self, value: int):
def __init__(self, value):
if not isinstance(value, int):
raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.")
raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.')
self.value: int = value
def _get_byte(self, byte: int) -> int:
return (self.value >> (8 * byte)) & 0xFF
return (self.value >> (8 * byte)) & 0xff
def __eq__(self, other: Any) -> bool:
return isinstance(other, Colour) and self.value == other.value
@ -94,13 +94,13 @@ class Colour:
return not self.__eq__(other)
def __str__(self) -> str:
return f"#{self.value:0>6x}"
return f'#{self.value:0>6x}'
def __int__(self) -> int:
return self.value
def __repr__(self) -> str:
return f"<Colour value={self.value}>"
return f'<Colour value={self.value}>'
def __hash__(self) -> int:
return hash(self.value)
@ -164,35 +164,27 @@ class Colour:
@classmethod
def teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``."""
return cls(0x1ABC9C)
return cls(0x1abc9c)
@classmethod
def dark_teal(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x11806a``."""
return cls(0x11806A)
@classmethod
def brand_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x57F287``.
.. versionadded:: 2.0
"""
return cls(0x57F287)
return cls(0x11806a)
@classmethod
def green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``."""
return cls(0x2ECC71)
return cls(0x2ecc71)
@classmethod
def dark_green(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``."""
return cls(0x1F8B4C)
return cls(0x1f8b4c)
@classmethod
def blue(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x3498db``."""
return cls(0x3498DB)
return cls(0x3498db)
@classmethod
def dark_blue(cls: Type[CT]) -> CT:
@ -202,100 +194,85 @@ class Colour:
@classmethod
def purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``."""
return cls(0x9B59B6)
return cls(0x9b59b6)
@classmethod
def dark_purple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x71368a``."""
return cls(0x71368A)
return cls(0x71368a)
@classmethod
def magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe91e63``."""
return cls(0xE91E63)
return cls(0xe91e63)
@classmethod
def dark_magenta(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xad1457``."""
return cls(0xAD1457)
return cls(0xad1457)
@classmethod
def gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``."""
return cls(0xF1C40F)
return cls(0xf1c40f)
@classmethod
def dark_gold(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``."""
return cls(0xC27C0E)
return cls(0xc27c0e)
@classmethod
def orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe67e22``."""
return cls(0xE67E22)
return cls(0xe67e22)
@classmethod
def dark_orange(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xa84300``."""
return cls(0xA84300)
@classmethod
def brand_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xED4245``.
.. versionadded:: 2.0
"""
return cls(0xED4245)
return cls(0xa84300)
@classmethod
def red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``."""
return cls(0xE74C3C)
@classmethod
def nitro_booster(cls):
"""A factory method that returns a :class:`Colour` with a value of ``0xf47fff``.
.. versionadded:: 2.0"""
return cls(0xF47FFF)
return cls(0xe74c3c)
@classmethod
def dark_red(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x992d22``."""
return cls(0x992D22)
return cls(0x992d22)
@classmethod
def lighter_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``."""
return cls(0x95A5A6)
return cls(0x95a5a6)
lighter_gray = lighter_grey
@classmethod
def dark_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x607d8b``."""
return cls(0x607D8B)
return cls(0x607d8b)
dark_gray = dark_grey
@classmethod
def light_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x979c9f``."""
return cls(0x979C9F)
return cls(0x979c9f)
light_gray = light_grey
@classmethod
def darker_grey(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x546e7a``."""
return cls(0x546E7A)
return cls(0x546e7a)
darker_gray = darker_grey
@classmethod
def og_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x7289da``."""
return cls(0x7289DA)
return cls(0x7289da)
@classmethod
def blurple(cls: Type[CT]) -> CT:
@ -305,7 +282,7 @@ class Colour:
@classmethod
def greyple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x99aab5``."""
return cls(0x99AAB5)
return cls(0x99aab5)
@classmethod
def dark_theme(cls: Type[CT]) -> CT:
@ -332,14 +309,5 @@ class Colour:
"""
return cls(0xFEE75C)
@classmethod
def dark_blurple(cls: Type[CT]) -> CT:
"""A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``.
This is the original Dark Blurple branding.
.. versionadded:: 2.0
"""
return cls(0x4E5D94)
Color = Colour

View File

@ -41,14 +41,14 @@ if TYPE_CHECKING:
__all__ = (
"Component",
"ActionRow",
"Button",
"SelectMenu",
"SelectOption",
'Component',
'ActionRow',
'Button',
'SelectMenu',
'SelectOption',
)
C = TypeVar("C", bound="Component")
C = TypeVar('C', bound='Component')
class Component:
@ -70,14 +70,14 @@ class Component:
The type of component.
"""
__slots__: Tuple[str, ...] = ("type",)
__slots__: Tuple[str, ...] = ('type',)
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str:
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__)
return f"<{self.__class__.__name__} {attrs}>"
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
@classmethod
def _raw_construct(cls: Type[C], **kwargs) -> C:
@ -112,18 +112,18 @@ class ActionRow(Component):
The children components that this holds, if any.
"""
__slots__: Tuple[str, ...] = ("children",)
__slots__: Tuple[str, ...] = ('children',)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.children: List[Component] = [_component_factory(d) for d in data.get("components", [])]
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
def to_dict(self) -> ActionRowPayload:
return {
"type": int(self.type),
"components": [child.to_dict() for child in self.children],
'type': int(self.type),
'components': [child.to_dict() for child in self.children],
} # type: ignore
@ -132,16 +132,11 @@ class Button(Component):
This inherits from :class:`Component`.
.. note::
The user constructible and usable type to create a button is :class:`discord.ui.Button`
not this one.
.. versionadded:: 2.0
Attributes
-----------
style: :class:`.ButtonStyle`
style: :class:`ComponentButtonStyle`
The style of the button.
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
@ -157,44 +152,44 @@ class Button(Component):
"""
__slots__: Tuple[str, ...] = (
"style",
"custom_id",
"url",
"disabled",
"label",
"emoji",
'style',
'custom_id',
'url',
'disabled',
'label',
'emoji',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data["type"])
self.style: ButtonStyle = try_enum(ButtonStyle, data["style"])
self.custom_id: Optional[str] = data.get("custom_id")
self.url: Optional[str] = data.get("url")
self.disabled: bool = data.get("disabled", False)
self.label: Optional[str] = data.get("label")
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
self.disabled: bool = data.get('disabled', False)
self.label: Optional[str] = data.get('label')
self.emoji: Optional[PartialEmoji]
try:
self.emoji = PartialEmoji.from_dict(data["emoji"])
self.emoji = PartialEmoji.from_dict(data['emoji'])
except KeyError:
self.emoji = None
def to_dict(self) -> ButtonComponentPayload:
payload = {
"type": 2,
"style": int(self.style),
"label": self.label,
"disabled": self.disabled,
'type': 2,
'style': int(self.style),
'label': self.label,
'disabled': self.disabled,
}
if self.custom_id:
payload["custom_id"] = self.custom_id
payload['custom_id'] = self.custom_id
if self.url:
payload["url"] = self.url
payload['url'] = self.url
if self.emoji:
payload["emoji"] = self.emoji.to_dict()
payload['emoji'] = self.emoji.to_dict()
return payload # type: ignore
@ -205,11 +200,6 @@ class SelectMenu(Component):
A select menu is functionally the same as a dropdown, however
on mobile it renders a bit differently.
.. note::
The user constructible and usable type to create a select menu is
:class:`discord.ui.Select` not this one.
.. versionadded:: 2.0
Attributes
@ -226,42 +216,37 @@ class SelectMenu(Component):
Defaults to 1 and must be between 1 and 25.
options: List[:class:`SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
"""
__slots__: Tuple[str, ...] = (
"custom_id",
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
'custom_id',
'placeholder',
'min_values',
'max_values',
'options',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select
self.custom_id: str = data["custom_id"]
self.placeholder: Optional[str] = data.get("placeholder")
self.min_values: int = data.get("min_values", 1)
self.max_values: int = data.get("max_values", 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])]
self.disabled: bool = data.get("disabled", False)
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
self.max_values: int = data.get('max_values', 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
"type": self.type.value,
"custom_id": self.custom_id,
"min_values": self.min_values,
"max_values": self.max_values,
"options": [op.to_dict() for op in self.options],
"disabled": self.disabled,
'type': self.type.value,
'custom_id': self.custom_id,
'min_values': self.min_values,
'max_values': self.max_values,
'options': [op.to_dict() for op in self.options],
}
if self.placeholder:
payload["placeholder"] = self.placeholder
payload['placeholder'] = self.placeholder
return payload
@ -277,14 +262,14 @@ class SelectOption:
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
Can only be up to 25 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
Can only be up to 50 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
The emoji of the option, if available.
default: :class:`bool`
@ -292,11 +277,11 @@ class SelectOption:
"""
__slots__: Tuple[str, ...] = (
"label",
"value",
"description",
"emoji",
"default",
'label',
'value',
'description',
'emoji',
'default',
)
def __init__(
@ -318,60 +303,50 @@ class SelectOption:
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji
self.default = default
def __repr__(self) -> str:
return (
f"<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} "
f"emoji={self.emoji!r} default={self.default!r}>"
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} '
f'emoji={self.emoji!r} default={self.default!r}>'
)
def __str__(self) -> str:
if self.emoji:
base = f"{self.emoji} {self.label}"
else:
base = self.label
if self.description:
return f"{base}\n{self.description}"
return base
@classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try:
emoji = PartialEmoji.from_dict(data["emoji"])
emoji = PartialEmoji.from_dict(data['emoji'])
except KeyError:
emoji = None
return cls(
label=data["label"],
value=data["value"],
description=data.get("description"),
label=data['label'],
value=data['value'],
description=data.get('description'),
emoji=emoji,
default=data.get("default", False),
default=data.get('default', False),
)
def to_dict(self) -> SelectOptionPayload:
payload: SelectOptionPayload = {
"label": self.label,
"value": self.value,
"default": self.default,
'label': self.label,
'value': self.value,
'default': self.default,
}
if self.emoji:
payload["emoji"] = self.emoji.to_dict() # type: ignore
payload['emoji'] = self.emoji.to_dict() # type: ignore
if self.description:
payload["description"] = self.description
payload['description'] = self.description
return payload
def _component_factory(data: ComponentPayload) -> Component:
component_type = data["type"]
component_type = data['type']
if component_type == 1:
return ActionRow(data)
elif component_type == 2:

View File

@ -22,35 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING, TypeVar, Optional, Type
if TYPE_CHECKING:
from .abc import Messageable
__all__ = (
'Typing',
)
from types import TracebackType
TypingT = TypeVar("TypingT", bound="Typing")
__all__ = ("Typing",)
def _typing_done_callback(fut: asyncio.Future) -> None:
def _typing_done_callback(fut):
# just retrieve any exception and call it a day
try:
fut.exception()
except (asyncio.CancelledError, Exception):
pass
class Typing:
def __init__(self, messageable: Messageable) -> None:
self.loop: asyncio.AbstractEventLoop = messageable._state.loop
self.messageable: Messageable = messageable
def __init__(self, messageable):
self.loop = messageable._state.loop
self.messageable = messageable
async def do_typing(self) -> None:
async def do_typing(self):
try:
channel = self._channel
except AttributeError:
@ -62,28 +52,18 @@ class Typing:
await typing(channel.id)
await asyncio.sleep(5)
def __enter__(self: TypingT) -> TypingT:
self.task: asyncio.Task = self.loop.create_task(self.do_typing())
def __enter__(self):
self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop)
self.task.add_done_callback(_typing_done_callback)
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
def __exit__(self, exc_type, exc, tb):
self.task.cancel()
async def __aenter__(self: TypingT) -> TypingT:
async def __aenter__(self):
self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id)
return self.__enter__()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
async def __aexit__(self, exc_type, exc, tb):
self.task.cancel()

View File

@ -25,12 +25,14 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from typing import Any, Dict, Final, List, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from . import utils
from .colour import Colour
__all__ = ("Embed",)
__all__ = (
'Embed',
)
class _EmptyEmbed:
@ -38,7 +40,7 @@ class _EmptyEmbed:
return False
def __repr__(self) -> str:
return "Embed.Empty"
return 'Embed.Empty'
def __len__(self) -> int:
return 0
@ -55,19 +57,19 @@ class EmbedProxy:
return len(self.__dict__)
def __repr__(self) -> str:
inner = ", ".join((f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")))
return f"EmbedProxy({inner})"
inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_')))
return f'EmbedProxy({inner})'
def __getattr__(self, attr: str) -> _EmptyEmbed:
return EmptyEmbed
E = TypeVar("E", bound="Embed")
E = TypeVar('E', bound='Embed')
if TYPE_CHECKING:
from discord.types.embed import Embed as EmbedData, EmbedType
T = TypeVar("T")
T = TypeVar('T')
MaybeEmpty = Union[T, _EmptyEmbed]
class _EmbedFooterProxy(Protocol):
@ -155,19 +157,19 @@ class Embed:
"""
__slots__ = (
"title",
"url",
"type",
"_timestamp",
"_colour",
"_footer",
"_image",
"_thumbnail",
"_video",
"_provider",
"_author",
"_fields",
"description",
'title',
'url',
'type',
'_timestamp',
'_colour',
'_footer',
'_image',
'_thumbnail',
'_video',
'_provider',
'_author',
'_fields',
'description',
)
Empty: Final = EmptyEmbed
@ -178,7 +180,7 @@ class Embed:
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed,
type: EmbedType = "rich",
type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed,
timestamp: datetime.datetime = None,
@ -200,10 +202,12 @@ class Embed:
self.url = str(self.url)
if timestamp:
if timestamp.tzinfo is None:
timestamp = timestamp.astimezone()
self.timestamp = timestamp
@classmethod
def from_dict(cls: Type[E], data: Mapping[str, Any]) -> E:
def from_dict(cls: Type[E], data: EmbedData) -> E:
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the
format that Discord expects it to be in.
@ -223,10 +227,10 @@ class Embed:
# fill in the basic fields
self.title = data.get("title", EmptyEmbed)
self.type = data.get("type", EmptyEmbed)
self.description = data.get("description", EmptyEmbed)
self.url = data.get("url", EmptyEmbed)
self.title = data.get('title', EmptyEmbed)
self.type = data.get('type', EmptyEmbed)
self.description = data.get('description', EmptyEmbed)
self.url = data.get('url', EmptyEmbed)
if self.title is not EmptyEmbed:
self.title = str(self.title)
@ -240,22 +244,22 @@ class Embed:
# try to fill in the more rich fields
try:
self._colour = Colour(value=data["color"])
self._colour = Colour(value=data['color'])
except KeyError:
pass
try:
self._timestamp = utils.parse_time(data["timestamp"])
self._timestamp = utils.parse_time(data['timestamp'])
except KeyError:
pass
for attr in ("thumbnail", "video", "provider", "author", "fields", "image", "footer"):
for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'):
try:
value = data[attr]
except KeyError:
continue
else:
setattr(self, "_" + attr, value)
setattr(self, '_' + attr, value)
return self
@ -265,11 +269,11 @@ class Embed:
def __len__(self) -> int:
total = len(self.title) + len(self.description)
for field in getattr(self, "_fields", []):
total += len(field["name"]) + len(field["value"])
for field in getattr(self, '_fields', []):
total += len(field['name']) + len(field['value'])
try:
footer_text = self._footer["text"]
footer_text = self._footer['text']
except (AttributeError, KeyError):
pass
else:
@ -280,7 +284,7 @@ class Embed:
except AttributeError:
pass
else:
total += len(author["name"])
total += len(author['name'])
return total
@ -304,7 +308,7 @@ class Embed:
@property
def colour(self) -> MaybeEmpty[Colour]:
return getattr(self, "_colour", EmptyEmbed)
return getattr(self, '_colour', EmptyEmbed)
@colour.setter
def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore
@ -313,23 +317,17 @@ class Embed:
elif isinstance(value, int):
self._colour = Colour(value=value)
else:
raise TypeError(
f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead."
)
raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.')
color = colour
@property
def timestamp(self) -> MaybeEmpty[datetime.datetime]:
return getattr(self, "_timestamp", EmptyEmbed)
return getattr(self, '_timestamp', EmptyEmbed)
@timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]):
if isinstance(value, datetime.datetime):
if value.tzinfo is None:
value = value.astimezone()
self._timestamp = value
elif isinstance(value, _EmptyEmbed):
if isinstance(value, (datetime.datetime, _EmptyEmbed)):
self._timestamp = value
else:
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead")
@ -342,7 +340,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_footer", {})) # type: ignore
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
"""Sets the footer for the embed content.
@ -360,13 +358,13 @@ class Embed:
self._footer = {}
if text is not EmptyEmbed:
self._footer["text"] = str(text)
self._footer['text'] = str(text)
if icon_url is not EmptyEmbed:
self._footer["icon_url"] = str(icon_url)
self._footer['icon_url'] = str(icon_url)
return self
def remove_footer(self: E) -> E:
"""Clears embed's footer information.
@ -381,7 +379,7 @@ class Embed:
pass
return self
@property
def image(self) -> _EmbedMediaProxy:
"""Returns an ``EmbedProxy`` denoting the image contents.
@ -395,21 +393,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_image", {})) # type: ignore
@image.setter
def image(self, url: Any):
if url is EmptyEmbed:
del self.image
else:
self._image = {"url": str(url)}
@image.deleter
def image(self):
try:
del self._image
except AttributeError:
pass
return EmbedProxy(getattr(self, '_image', {})) # type: ignore
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the image for the embed content.
@ -426,7 +410,16 @@ class Embed:
The source URL for the image. Only HTTP(S) is supported.
"""
self.image = url
if url is EmptyEmbed:
try:
del self._image
except AttributeError:
pass
else:
self._image = {
'url': str(url),
}
return self
@property
@ -442,23 +435,9 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
@thumbnail.setter
def thumbnail(self, url: Any):
if url is EmptyEmbed:
del self.thumbnail
else:
self._thumbnail = {"url": str(url)}
@thumbnail.deleter
def thumbnail(self):
try:
del self._thumbnail
except AttributeError:
pass
def set_thumbnail(self, *, url: MaybeEmpty[Any]):
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E:
"""Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style
@ -473,7 +452,16 @@ class Embed:
The source URL for the thumbnail. Only HTTP(S) is supported.
"""
self.thumbnail = url
if url is EmptyEmbed:
try:
del self._thumbnail
except AttributeError:
pass
else:
self._thumbnail = {
'url': str(url),
}
return self
@property
@ -488,7 +476,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_video", {})) # type: ignore
return EmbedProxy(getattr(self, '_video', {})) # type: ignore
@property
def provider(self) -> _EmbedProviderProxy:
@ -498,7 +486,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_provider", {})) # type: ignore
return EmbedProxy(getattr(self, '_provider', {})) # type: ignore
@property
def author(self) -> _EmbedAuthorProxy:
@ -508,11 +496,9 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return EmbedProxy(getattr(self, "_author", {})) # type: ignore
return EmbedProxy(getattr(self, '_author', {})) # type: ignore
def set_author(
self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed
) -> E:
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E:
"""Sets the author for the embed content.
This function returns the class instance to allow for fluent-style
@ -529,14 +515,14 @@ class Embed:
"""
self._author = {
"name": str(name),
'name': str(name),
}
if url is not EmptyEmbed:
self._author["url"] = str(url)
self._author['url'] = str(url)
if icon_url is not EmptyEmbed:
self._author["icon_url"] = str(icon_url)
self._author['icon_url'] = str(icon_url)
return self
@ -563,7 +549,7 @@ class Embed:
If the attribute has no value then :attr:`Empty` is returned.
"""
return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E:
"""Adds a field to the embed object.
@ -582,9 +568,9 @@ class Embed:
"""
field = {
"inline": inline,
"name": str(name),
"value": str(value),
'inline': inline,
'name': str(name),
'value': str(value),
}
try:
@ -615,9 +601,9 @@ class Embed:
"""
field = {
"inline": inline,
"name": str(name),
"value": str(value),
'inline': inline,
'name': str(name),
'value': str(value),
}
try:
@ -683,11 +669,11 @@ class Embed:
try:
field = self._fields[index]
except (TypeError, IndexError, AttributeError):
raise IndexError("field index out of range")
raise IndexError('field index out of range')
field["name"] = str(name)
field["value"] = str(value)
field["inline"] = inline
field['name'] = str(name)
field['value'] = str(value)
field['inline'] = inline
return self
def to_dict(self) -> EmbedData:
@ -705,35 +691,35 @@ class Embed:
# deal with basic convenience wrappers
try:
colour = result.pop("colour")
colour = result.pop('colour')
except KeyError:
pass
else:
if colour:
result["color"] = colour.value
result['color'] = colour.value
try:
timestamp = result.pop("timestamp")
timestamp = result.pop('timestamp')
except KeyError:
pass
else:
if timestamp:
if timestamp.tzinfo:
result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat()
else:
result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat()
# add in the non raw attribute ones
if self.type:
result["type"] = self.type
result['type'] = self.type
if self.description:
result["description"] = self.description
result['description'] = self.description
if self.url:
result["url"] = self.url
result['url'] = self.url
if self.title:
result["title"] = self.title
result['title'] = self.title
return result # type: ignore

View File

@ -30,7 +30,9 @@ from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User
__all__ = ("Emoji",)
__all__ = (
'Emoji',
)
if TYPE_CHECKING:
from .types.emoji import Emoji as EmojiPayload
@ -70,10 +72,6 @@ class Emoji(_EmojiTag, AssetMixin):
Returns the emoji rendered for discord.
.. describe:: int(x)
Returns the emoji ID.
Attributes
-----------
name: :class:`str`
@ -96,16 +94,16 @@ class Emoji(_EmojiTag, AssetMixin):
"""
__slots__: Tuple[str, ...] = (
"require_colons",
"animated",
"managed",
"id",
"name",
"_roles",
"guild_id",
"_state",
"user",
"available",
'require_colons',
'animated',
'managed',
'id',
'name',
'_roles',
'guild_id',
'_state',
'user',
'available',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
@ -114,14 +112,14 @@ class Emoji(_EmojiTag, AssetMixin):
self._from_data(data)
def _from_data(self, emoji: EmojiPayload):
self.require_colons: bool = emoji.get("require_colons", False)
self.managed: bool = emoji.get("managed", False)
self.id: int = int(emoji["id"]) # type: ignore
self.name: str = emoji["name"] # type: ignore
self.animated: bool = emoji.get("animated", False)
self.available: bool = emoji.get("available", True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", [])))
user = emoji.get("user")
self.require_colons: bool = emoji.get('require_colons', False)
self.managed: bool = emoji.get('managed', False)
self.id: int = int(emoji['id']) # type: ignore
self.name: str = emoji['name'] # type: ignore
self.animated: bool = emoji.get('animated', False)
self.available: bool = emoji.get('available', True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', [])))
user = emoji.get('user')
self.user: Optional[User] = User(state=self._state, data=user) if user else None
def _to_partial(self) -> PartialEmoji:
@ -129,21 +127,18 @@ class Emoji(_EmojiTag, AssetMixin):
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for attr in self.__slots__:
if attr[0] != "_":
if attr[0] != '_':
value = getattr(self, attr, None)
if value is not None:
yield (attr, value)
def __str__(self) -> str:
if self.animated:
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
def __int__(self) -> int:
return self.id
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __repr__(self) -> str:
return f"<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>"
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>'
def __eq__(self, other: Any) -> bool:
return isinstance(other, _EmojiTag) and self.id == other.id
@ -162,8 +157,8 @@ class Emoji(_EmojiTag, AssetMixin):
@property
def url(self) -> str:
""":class:`str`: Returns the URL of the emoji."""
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
@property
def roles(self) -> List[Role]:
@ -217,9 +212,7 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
async def edit(
self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None
) -> Emoji:
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> None:
r"""|coro|
Edits the custom emoji.
@ -227,9 +220,6 @@ class Emoji(_EmojiTag, AssetMixin):
You must have :attr:`~Permissions.manage_emojis` permission to
do this.
.. versionchanged:: 2.0
The newly updated emoji is returned.
Parameters
-----------
name: :class:`str`
@ -245,18 +235,12 @@ class Emoji(_EmojiTag, AssetMixin):
You are not allowed to edit emojis.
HTTPException
An error occurred editing the emoji.
Returns
--------
:class:`Emoji`
The newly updated emoji.
"""
payload = {}
if name is not MISSING:
payload["name"] = name
payload['name'] = name
if roles is not MISSING:
payload["roles"] = [role.id for role in roles]
payload['roles'] = [role.id for role in roles]
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state)
await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)

View File

@ -24,73 +24,57 @@ DEALINGS IN THE SOFTWARE.
import types
from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar
from typing import Any, Dict, Optional, TYPE_CHECKING, Type, TypeVar
__all__ = (
"Enum",
"ChannelType",
"MessageType",
"VoiceRegion",
"SpeakingState",
"VerificationLevel",
"ContentFilter",
"Status",
"DefaultAvatar",
"AuditLogAction",
"AuditLogActionCategory",
"UserFlags",
"ActivityType",
"NotificationLevel",
"TeamMembershipState",
"WebhookType",
"ExpireBehaviour",
"ExpireBehavior",
"StickerType",
"StickerFormatType",
"InviteTarget",
"VideoQualityMode",
"ComponentType",
"ButtonStyle",
"StagePrivacyLevel",
"InteractionType",
"InteractionResponseType",
"NSFWLevel",
"ProtocolURL",
'Enum',
'ChannelType',
'MessageType',
'VoiceRegion',
'SpeakingState',
'VerificationLevel',
'ContentFilter',
'Status',
'DefaultAvatar',
'AuditLogAction',
'AuditLogActionCategory',
'UserFlags',
'ActivityType',
'NotificationLevel',
'TeamMembershipState',
'WebhookType',
'ExpireBehaviour',
'ExpireBehavior',
'StickerType',
'InviteTarget',
'VideoQualityMode',
'ComponentType',
'ButtonStyle',
'StagePrivacyLevel',
'InteractionType',
'InteractionResponseType',
'NSFWLevel',
)
def _create_value_cls(name, comparable):
cls = namedtuple("_EnumValue_" + name, "name value")
cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>"
cls.__str__ = lambda self: f"{name}.{self.name}"
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
def _create_value_cls(name):
cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}'
return cls
def _is_descriptor(obj):
return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
class EnumMeta(type):
if TYPE_CHECKING:
__name__: ClassVar[str]
_enum_member_names_: ClassVar[List[str]]
_enum_member_map_: ClassVar[Dict[str, Any]]
_enum_value_map_: ClassVar[Dict[Any, Any]]
def __new__(cls, name, bases, attrs, *, comparable: bool = False):
def __new__(cls, name, bases, attrs):
value_mapping = {}
member_mapping = {}
member_names = []
value_cls = _create_value_cls(name, comparable)
value_cls = _create_value_cls(name)
for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value)
if key[0] == "_" and not is_descriptor:
if key[0] == '_' and not is_descriptor:
continue
# Special case classmethod to just pass through
@ -112,12 +96,12 @@ class EnumMeta(type):
member_mapping[key] = new_value
attrs[key] = new_value
attrs["_enum_value_map_"] = value_mapping
attrs["_enum_member_map_"] = member_mapping
attrs["_enum_member_names_"] = member_names
attrs["_enum_value_cls_"] = value_cls
attrs['_enum_value_map_'] = value_mapping
attrs['_enum_member_map_'] = member_mapping
attrs['_enum_member_names_'] = member_names
attrs['_enum_value_cls_'] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore
value_cls._actual_enum_cls_ = actual_cls
return actual_cls
def __iter__(cls):
@ -130,7 +114,7 @@ class EnumMeta(type):
return len(cls._enum_member_names_)
def __repr__(cls):
return f"<enum {cls.__name__}>"
return f'<enum {cls.__name__}>'
@property
def __members__(cls):
@ -146,10 +130,10 @@ class EnumMeta(type):
return cls._enum_member_map_[key]
def __setattr__(cls, name, value):
raise TypeError("Enums are immutable.")
raise TypeError('Enums are immutable.')
def __delattr__(cls, attr):
raise TypeError("Enums are immutable")
raise TypeError('Enums are immutable')
def __instancecheck__(self, instance):
# isinstance(x, Y)
@ -159,11 +143,9 @@ class EnumMeta(type):
except AttributeError:
return False
if TYPE_CHECKING:
from enum import Enum
else:
class Enum(metaclass=EnumMeta):
@classmethod
def try_value(cls, value):
@ -172,84 +154,80 @@ else:
except (KeyError, TypeError):
return value
class ChannelType(Enum):
text = 0
private = 1
voice = 2
group = 3
category = 4
news = 5
store = 6
news_thread = 10
public_thread = 11
text = 0
private = 1
voice = 2
group = 3
category = 4
news = 5
store = 6
news_thread = 10
public_thread = 11
private_thread = 12
stage_voice = 13
stage_voice = 13
def __str__(self):
return self.name
class MessageType(Enum):
default = 0
recipient_add = 1
recipient_remove = 2
call = 3
channel_name_change = 4
channel_icon_change = 5
pins_add = 6
new_member = 7
premium_guild_subscription = 8
premium_guild_tier_1 = 9
premium_guild_tier_2 = 10
premium_guild_tier_3 = 11
channel_follow_add = 12
guild_stream = 13
guild_discovery_disqualified = 14
guild_discovery_requalified = 15
default = 0
recipient_add = 1
recipient_remove = 2
call = 3
channel_name_change = 4
channel_icon_change = 5
pins_add = 6
new_member = 7
premium_guild_subscription = 8
premium_guild_tier_1 = 9
premium_guild_tier_2 = 10
premium_guild_tier_3 = 11
channel_follow_add = 12
guild_stream = 13
guild_discovery_disqualified = 14
guild_discovery_requalified = 15
guild_discovery_grace_period_initial_warning = 16
guild_discovery_grace_period_final_warning = 17
thread_created = 18
reply = 19
application_command = 20
thread_starter_message = 21
guild_invite_reminder = 22
guild_discovery_grace_period_final_warning = 17
thread_created = 18
reply = 19
application_command = 20
thread_starter_message = 21
guild_invite_reminder = 22
class VoiceRegion(Enum):
us_west = "us-west"
us_east = "us-east"
us_south = "us-south"
us_central = "us-central"
eu_west = "eu-west"
eu_central = "eu-central"
singapore = "singapore"
london = "london"
sydney = "sydney"
amsterdam = "amsterdam"
frankfurt = "frankfurt"
brazil = "brazil"
hongkong = "hongkong"
russia = "russia"
japan = "japan"
southafrica = "southafrica"
south_korea = "south-korea"
india = "india"
europe = "europe"
dubai = "dubai"
vip_us_east = "vip-us-east"
vip_us_west = "vip-us-west"
vip_amsterdam = "vip-amsterdam"
us_west = 'us-west'
us_east = 'us-east'
us_south = 'us-south'
us_central = 'us-central'
eu_west = 'eu-west'
eu_central = 'eu-central'
singapore = 'singapore'
london = 'london'
sydney = 'sydney'
amsterdam = 'amsterdam'
frankfurt = 'frankfurt'
brazil = 'brazil'
hongkong = 'hongkong'
russia = 'russia'
japan = 'japan'
southafrica = 'southafrica'
south_korea = 'south-korea'
india = 'india'
europe = 'europe'
dubai = 'dubai'
vip_us_east = 'vip-us-east'
vip_us_west = 'vip-us-west'
vip_amsterdam = 'vip-amsterdam'
def __str__(self):
return self.value
class SpeakingState(Enum):
none = 0
voice = 1
none = 0
voice = 1
soundshare = 2
priority = 4
priority = 4
def __str__(self):
return self.name
@ -257,64 +235,56 @@ class SpeakingState(Enum):
def __int__(self):
return self.value
class VerificationLevel(Enum, comparable=True):
none = 0
low = 1
medium = 2
high = 3
class VerificationLevel(Enum):
none = 0
low = 1
medium = 2
high = 3
highest = 4
def __str__(self):
return self.name
class ContentFilter(Enum, comparable=True):
disabled = 0
no_role = 1
class ContentFilter(Enum):
disabled = 0
no_role = 1
all_members = 2
def __str__(self):
return self.name
class Status(Enum):
online = "online"
offline = "offline"
idle = "idle"
dnd = "dnd"
do_not_disturb = "dnd"
invisible = "invisible"
online = 'online'
offline = 'offline'
idle = 'idle'
dnd = 'dnd'
do_not_disturb = 'dnd'
invisible = 'invisible'
def __str__(self):
return self.value
class DefaultAvatar(Enum):
blurple = 0
grey = 1
gray = 1
green = 2
orange = 3
red = 4
grey = 1
gray = 1
green = 2
orange = 3
red = 4
def __str__(self):
return self.name
class NotificationLevel(Enum, comparable=True):
all_messages = 0
class NotificationLevel(Enum):
all_messages = 0
only_mentions = 1
class AuditLogActionCategory(Enum):
create = 1
delete = 2
update = 3
class AuditLogAction(Enum):
# fmt: off
guild_update = 1
channel_create = 10
channel_update = 11
@ -353,17 +323,9 @@ class AuditLogAction(Enum):
stage_instance_create = 83
stage_instance_update = 84
stage_instance_delete = 85
sticker_create = 90
sticker_update = 91
sticker_delete = 92
thread_create = 110
thread_update = 111
thread_delete = 112
# fmt: on
@property
def category(self) -> Optional[AuditLogActionCategory]:
# fmt: off
lookup: Dict[AuditLogAction, Optional[AuditLogActionCategory]] = {
AuditLogAction.guild_update: AuditLogActionCategory.update,
AuditLogAction.channel_create: AuditLogActionCategory.create,
@ -403,48 +365,36 @@ class AuditLogAction(Enum):
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
AuditLogAction.sticker_create: AuditLogActionCategory.create,
AuditLogAction.sticker_update: AuditLogActionCategory.update,
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_create: AuditLogActionCategory.create,
AuditLogAction.thread_update: AuditLogActionCategory.update,
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
}
# fmt: on
return lookup[self]
@property
def target_type(self) -> Optional[str]:
v = self.value
if v == -1:
return "all"
return 'all'
elif v < 10:
return "guild"
return 'guild'
elif v < 20:
return "channel"
return 'channel'
elif v < 30:
return "user"
return 'user'
elif v < 40:
return "role"
return 'role'
elif v < 50:
return "invite"
return 'invite'
elif v < 60:
return "webhook"
return 'webhook'
elif v < 70:
return "emoji"
return 'emoji'
elif v == 73:
return "channel"
return 'channel'
elif v < 80:
return "message"
return 'message'
elif v < 83:
return "integration"
return 'integration'
elif v < 90:
return "stage_instance"
elif v < 93:
return "sticker"
elif v < 113:
return "thread"
return 'stage_instance'
class UserFlags(Enum):
staff = 1
@ -465,7 +415,6 @@ class UserFlags(Enum):
verified_bot_developer = 131072
discord_certified_moderator = 262144
class ActivityType(Enum):
unknown = -1
playing = 0
@ -478,60 +427,35 @@ class ActivityType(Enum):
def __int__(self):
return self.value
class TeamMembershipState(Enum):
invited = 1
accepted = 2
class WebhookType(Enum):
incoming = 1
channel_follower = 2
application = 3
class ExpireBehaviour(Enum):
remove_role = 0
kick = 1
ExpireBehavior = ExpireBehaviour
class StickerType(Enum):
standard = 1
guild = 2
class StickerFormatType(Enum):
png = 1
apng = 2
lottie = 3
@property
def file_extension(self) -> str:
# fmt: off
lookup: Dict[StickerFormatType, str] = {
StickerFormatType.png: 'png',
StickerFormatType.apng: 'png',
StickerFormatType.lottie: 'json',
}
# fmt: on
return lookup[self]
class InviteTarget(Enum):
unknown = 0
stream = 1
embedded_application = 2
class InteractionType(Enum):
ping = 1
application_command = 2
component = 3
application_command_autocomplete = 4
class InteractionResponseType(Enum):
pong = 1
@ -540,9 +464,7 @@ class InteractionResponseType(Enum):
channel_message = 4 # (with source)
deferred_channel_message = 5 # (with source)
deferred_message_update = 6 # for components
message_update = 7 # for components
application_command_autocomplete_result = 8
message_update = 7 # for components
class VideoQualityMode(Enum):
auto = 1
@ -551,7 +473,6 @@ class VideoQualityMode(Enum):
def __int__(self):
return self.value
class ComponentType(Enum):
action_row = 1
button = 2
@ -560,7 +481,6 @@ class ComponentType(Enum):
def __int__(self):
return self.value
class ButtonStyle(Enum):
primary = 1
secondary = 2
@ -571,105 +491,30 @@ class ButtonStyle(Enum):
# Aliases
blurple = 1
grey = 2
gray = 2
green = 3
red = 4
url = 5
def __int__(self):
return self.value
class StagePrivacyLevel(Enum):
public = 1
closed = 2
guild_only = 2
class NSFWLevel(Enum, comparable=True):
class NSFWLevel(Enum):
default = 0
explicit = 1
safe = 2
age_restricted = 3
class ProtocolURL(Enum):
# General
home = "discord://-/channels/@me/"
nitro = "discord://-/store"
apps = "discord://-/apps" # Breaks the client on windows (Shows download links for different OS)
guild_discovery = "discord://-/guild-discovery"
guild_create = "discord://-/guilds/create"
guild_invite = "discord://-/invite/{invite_code}"
# Settings
account_settings = "discord://-/settings/account"
profile_settings = "discord://-/settings/profile-customization"
privacy_settings = "discord://-/settings/privacy-and-safety"
safety_settings = "discord://-/settings/privacy-and-safety" # Alias
authorized_apps_settings = "discord://-/settings/authorized-apps"
connections_settings = "discord://-/settings/connections"
nitro_settings = "discord://-/settings/premium" # Same as store, but inside of settings
guild_premium_subscription = "discord://-/settings/premium-guild-subscription"
subscription_settings = "discord://-/settings/subscriptions"
gift_inventory_settings = "discord://-/settings/inventory"
billing_settings = "discord://-/settings/billing"
appearance_settings = "discord://-/settings/appearance"
accessibility_settings = "discord://-/settings/accessibility"
voice_video_settings = "discord://-/settings/voice"
text_images_settings = "discord://-/settings/text"
notifications_settings = "discord://-/settings/notifications"
keybinds_settings = "discord://-/settings/keybinds"
language_settings = "discord://-/settings/locale"
windows_settings = "discord://-/settings/windows" # Doesnt work if used on wrong platform
linux_settings = "discord://-/settings/linux" # Doesnt work if used on wrong platform
streamer_mode_settings = "discord://-/settings/streamer-mode"
advanced_settings = "discord://-/settings/advanced"
activity_status_settings = "discord://-/settings/activity-status"
game_overlay_settings = "discord://-/settings/overlay"
hypesquad_settings = "discord://-/settings/hypesquad-online"
changelogs = "discord://-/settings/changelogs"
# Doesn't work if you don't have it actually activated. Just blank screen.
experiments = "discord://-/settings/experiments"
developer_options = "discord://-/settings/developer-options" # Same as experiments
hotspot_options = "discord://-/settings/hotspot-options" # Same as experiments
# Users, Guilds, and DMs
user_profile = "discord://-/users/{user_id}"
dm_channel = "discord://-/channels/@me/{channel_id}"
dm_message = "discord://-/channels/@me/{channel_id}/{message_id}"
guild_channel = "discord://-/channels/{guild_id}/{channel_id}"
guild_message = "discord://-/channels/{guild_id}/{channel_id}/{message_id}"
guild_membership_screening = "discord://-/member-verification/{guild_id}"
# Library
games_library = "discord://-/library"
library_settings = "discord://-/library/settings"
def __str__(self) -> str:
return self.value
def format(self, **kwargs: Any) -> str:
return self.value.format(**kwargs)
T = TypeVar("T")
T = TypeVar('T')
def create_unknown_value(cls: Type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore
name = f"unknown_{val}"
value_cls = cls._enum_value_cls_ # type: ignore
name = f'unknown_{val}'
return value_cls(name=name, value=val)
def try_enum(cls: Type[T], val: Any) -> T:
"""A function that tries to turn the value into enum ``cls``.
@ -677,6 +522,6 @@ def try_enum(cls: Type[T], val: Any) -> T:
"""
try:
return cls._enum_value_map_[val] # type: ignore
return cls._enum_value_map_[val] # type: ignore
except (KeyError, TypeError, AttributeError):
return create_unknown_value(cls, val)

View File

@ -22,91 +22,67 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING, Any, Tuple, Union
if TYPE_CHECKING:
from aiohttp import ClientResponse, ClientWebSocketResponse
try:
from requests import Response
_ResponseType = Union[ClientResponse, Response]
except ModuleNotFoundError:
_ResponseType = ClientResponse
from .interactions import Interaction
__all__ = (
"DiscordException",
"ClientException",
"NoMoreItems",
"GatewayNotFound",
"HTTPException",
"Forbidden",
"NotFound",
"DiscordServerError",
"InvalidData",
"InvalidArgument",
"LoginFailure",
"ConnectionClosed",
"PrivilegedIntentsRequired",
"InteractionResponded",
'DiscordException',
'ClientException',
'NoMoreItems',
'GatewayNotFound',
'HTTPException',
'Forbidden',
'NotFound',
'DiscordServerError',
'InvalidData',
'InvalidArgument',
'LoginFailure',
'ConnectionClosed',
'PrivilegedIntentsRequired',
)
class DiscordException(Exception):
"""Base exception class for discord.py
Ideally speaking, this could be caught to handle any exceptions raised from this library.
Ideally speaking, this could be caught to handle any exceptions thrown from this library.
"""
pass
class ClientException(DiscordException):
"""Exception that's raised when an operation in the :class:`Client` fails.
"""Exception that's thrown when an operation in the :class:`Client` fails.
These are usually for exceptions that happened due to user input.
"""
pass
class NoMoreItems(DiscordException):
"""Exception that is raised when an async iteration operation has no more items."""
"""Exception that is thrown when an async iteration operation has no more
items."""
pass
class GatewayNotFound(DiscordException):
"""An exception that is raised when the gateway for Discord could not be found"""
"""An exception that is usually thrown when the gateway hub
for the :class:`Client` websocket is not found."""
def __init__(self):
message = "The gateway to connect to discord was not found."
message = 'The gateway to connect to discord was not found.'
super().__init__(message)
def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]:
items: List[Tuple[str, str]] = []
def flatten_error_dict(d, key=''):
items = []
for k, v in d.items():
new_key = key + "." + k if key else k
new_key = key + '.' + k if key else k
if isinstance(v, dict):
try:
_errors: List[Dict[str, Any]] = v["_errors"]
_errors = v['_errors']
except KeyError:
items.extend(_flatten_error_dict(v, new_key).items())
items.extend(flatten_error_dict(v, new_key).items())
else:
items.append((new_key, " ".join(x.get("message", "") for x in _errors)))
items.append((new_key, ' '.join(x.get('message', '') for x in _errors)))
else:
items.append((new_key, v))
return dict(items)
class HTTPException(DiscordException):
"""Exception that's raised when an HTTP request operation fails.
"""Exception that's thrown when an HTTP request operation fails.
Attributes
------------
@ -123,92 +99,77 @@ class HTTPException(DiscordException):
The Discord specific error code for the failure.
"""
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]):
self.response: _ResponseType = response
self.status: int = response.status # type: ignore
self.code: int
self.text: str
def __init__(self, response, message):
self.response = response
self.status = response.status
if isinstance(message, dict):
self.code = message.get("code", 0)
base = message.get("message", "")
errors = message.get("errors")
self.code = message.get('code', 0)
base = message.get('message', '')
errors = message.get('errors')
if errors:
errors = _flatten_error_dict(errors)
helpful = "\n".join("In %s: %s" % t for t in errors.items())
self.text = base + "\n" + helpful
errors = flatten_error_dict(errors)
helpful = '\n'.join('In %s: %s' % t for t in errors.items())
self.text = base + '\n' + helpful
else:
self.text = base
else:
self.text = message or ""
self.text = message
self.code = 0
fmt = "{0.status} {0.reason} (error code: {1})"
fmt = '{0.status} {0.reason} (error code: {1})'
if len(self.text):
fmt += ": {2}"
fmt += ': {2}'
super().__init__(fmt.format(self.response, self.code, self.text))
class Forbidden(HTTPException):
"""Exception that's raised for when status code 403 occurs.
"""Exception that's thrown for when status code 403 occurs.
Subclass of :exc:`HTTPException`
"""
pass
class NotFound(HTTPException):
"""Exception that's raised for when status code 404 occurs.
"""Exception that's thrown for when status code 404 occurs.
Subclass of :exc:`HTTPException`
"""
pass
class DiscordServerError(HTTPException):
"""Exception that's raised for when a 500 range status code occurs.
"""Exception that's thrown for when a 500 range status code occurs.
Subclass of :exc:`HTTPException`.
.. versionadded:: 1.5
"""
pass
class InvalidData(ClientException):
"""Exception that's raised when the library encounters unknown
or invalid data from Discord.
"""
pass
class InvalidArgument(ClientException):
"""Exception that's raised when an argument to a function
"""Exception that's thrown when an argument to a function
is invalid some way (e.g. wrong value or wrong type).
This could be considered the analogous of ``ValueError`` and
``TypeError`` except inherited from :exc:`ClientException` and thus
:exc:`DiscordException`.
"""
pass
class LoginFailure(ClientException):
"""Exception that's raised when the :meth:`Client.login` function
"""Exception that's thrown when the :meth:`Client.login` function
fails to log you in from improper credentials or some other misc.
failure.
"""
pass
class ConnectionClosed(ClientException):
"""Exception that's raised when the gateway connection is
"""Exception that's thrown when the gateway connection is
closed for reasons that could not be handled internally.
Attributes
@ -220,19 +181,17 @@ class ConnectionClosed(ClientException):
shard_id: Optional[:class:`int`]
The shard ID that got closed if applicable.
"""
def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None):
def __init__(self, socket, *, shard_id, code=None):
# This exception is just the same exception except
# reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1
self.code = code or socket.close_code
# aiohttp doesn't seem to consistently provide close reason
self.reason: str = ""
self.shard_id: Optional[int] = shard_id
super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}")
self.reason = ''
self.shard_id = shard_id
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}')
class PrivilegedIntentsRequired(ClientException):
"""Exception that's raised when the gateway is requesting privileged intents
"""Exception that's thrown when the gateway is requesting privileged intents
but they're not ticked in the developer page yet.
Go to https://discord.com/developers/applications/ and enable the intents
@ -247,31 +206,10 @@ class PrivilegedIntentsRequired(ClientException):
The shard ID that got closed if applicable.
"""
def __init__(self, shard_id: Optional[int]):
self.shard_id: Optional[int] = shard_id
msg = (
"Shard ID %s is requesting privileged intents that have not been explicitly enabled in the "
"developer portal. It is recommended to go to https://discord.com/developers/applications/ "
"and explicitly enable the privileged intents within your application's page. If this is not "
"possible, then consider disabling the privileged intents instead."
)
def __init__(self, shard_id):
self.shard_id = shard_id
msg = 'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' \
'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' \
'and explicitly enable the privileged intents within your application\'s page. If this is not ' \
'possible, then consider disabling the privileged intents instead.'
super().__init__(msg % shard_id)
class InteractionResponded(ClientException):
"""Exception that's raised when sending another interaction response using
:class:`InteractionResponse` when one has already been done before.
An interaction can only respond once.
.. versionadded:: 2.0
Attributes
-----------
interaction: :class:`Interaction`
The interaction that's already been responded to.
"""
def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction
super().__init__("This interaction has already been responded to before")

View File

@ -22,28 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar("T")
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]
]
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand:

View File

@ -22,129 +22,38 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import collections
import collections.abc
from functools import cached_property
import inspect
import importlib.util
import sys
import traceback
import types
from collections import defaultdict
from discord.http import HTTPClient
from typing import (
Any,
Callable,
Iterable,
Tuple,
cast,
Mapping,
List,
Dict,
TYPE_CHECKING,
Optional,
TypeVar,
Type,
Union,
)
import discord
from discord.types.interactions import (
ApplicationCommandInteractionData,
ApplicationCommandInteractionDataOption,
EditApplicationCommand,
_ApplicationCommandInteractionDataOptionString,
)
from .core import GroupMixin
from .converter import Greedy
from .view import StringView, supported_quotes
from .view import StringView
from .context import Context
from .flags import FlagConverter
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
if TYPE_CHECKING:
import importlib.machinery
from discord.role import Role
from discord.message import Message
from discord.abc import PartialMessageableChannel
from ._types import (
Check,
CoroFunc,
)
__all__ = (
"when_mentioned",
"when_mentioned_or",
"Bot",
"AutoShardedBot",
'when_mentioned',
'when_mentioned_or',
'Bot',
'AutoShardedBot',
)
MISSING: Any = discord.utils.MISSING
T = TypeVar("T")
CFT = TypeVar("CFT", bound="CoroFunc")
CXT = TypeVar("CXT", bound="Context")
class _FakeSlashMessage(discord.PartialMessage):
activity = application = edited_at = reference = webhook_id = None
attachments = components = reactions = stickers = []
tts = False
raw_mentions = discord.Message.raw_mentions
clean_content = discord.Message.clean_content
channel_mentions = discord.Message.channel_mentions
raw_role_mentions = discord.Message.raw_role_mentions
raw_channel_mentions = discord.Message.raw_channel_mentions
author: Union[discord.User, discord.Member]
@classmethod
def from_interaction(
cls, interaction: discord.Interaction, channel: Union[discord.TextChannel, discord.DMChannel, discord.Thread]
):
self = cls(channel=channel, id=interaction.id)
assert interaction.user is not None
self.author = interaction.user
return self
@cached_property
def mentions(self) -> List[Union[discord.Member, discord.User]]:
client = self._state._get_client()
if self.guild:
ensure_user = lambda id: self.guild.get_member(id) or client.get_user(id) # type: ignore
else:
ensure_user = client.get_user
return discord.utils._unique(filter(None, map(ensure_user, self.raw_mentions)))
@cached_property
def role_mentions(self) -> List[Role]:
if self.guild is None:
return []
return discord.utils._unique(filter(None, map(self.guild.get_role, self.raw_role_mentions)))
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
def when_mentioned(bot, msg):
"""A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
"""
# bot.user will never be None when this is called
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]:
def when_mentioned_or(*prefixes):
"""A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@ -173,7 +82,6 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
----------
:func:`.when_mentioned`
"""
def inner(bot, msg):
r = list(prefixes)
r = when_mentioned(bot, msg) + r
@ -181,88 +89,37 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner
def _is_submodule(parent: str, child: str) -> bool:
def _is_submodule(parent, child):
return parent == child or child.startswith(parent + ".")
def _unwrap_slash_groups(
data: ApplicationCommandInteractionData,
) -> Tuple[str, Dict[str, ApplicationCommandInteractionDataOption]]:
command_name = data["name"]
command_options: Any = data.get("options") or []
while True:
try:
option = next(o for o in command_options if o["type"] in {1, 2})
except StopIteration:
return command_name, {o["name"]: o for o in command_options}
else:
command_name += f' {option["name"]}'
command_options = option.get("options") or []
def _quote_string_safe(string: str) -> str:
# we need to quote this string otherwise we may spill into
# other parameters and cause all kinds of trouble, as many
# quotes are supported and some may be in the option, we
# loop through all supported quotes and if neither open or
# close are in the string, we add them
for open, close in supported_quotes.items():
if open not in string and close not in string:
return f"{open}{string}{close}"
# all supported quotes are in the message and we cannot add any
# safely, very unlikely but still got to be covered
raise errors.UnexpectedQuoteError(string)
class _DefaultRepr:
def __repr__(self):
return "<default-help-command>"
return '<default-help-command>'
_default = _DefaultRepr()
class BotBase(GroupMixin):
def __init__(
self,
command_prefix,
help_command=_default,
description=None,
*,
intents: discord.Intents,
message_commands: bool = True,
slash_commands: bool = False,
**options,
):
super().__init__(**options, intents=intents)
def __init__(self, command_prefix, help_command=_default, description=None, **options):
super().__init__(**options)
self.command_prefix = command_prefix
self.slash_commands = slash_commands
self.message_commands = message_commands
self.extra_events: Dict[str, List[CoroFunc]] = {}
self.__cogs: Dict[str, Cog] = {}
self.__extensions: Dict[str, types.ModuleType] = {}
self._checks: List[Check] = []
self.extra_events = {}
self.__cogs = {}
self.__extensions = {}
self._checks = []
self._check_once = []
self._before_invoke = None
self._after_invoke = None
self._help_command = None
self.description = inspect.cleandoc(description) if description else ""
self.owner_id = options.get("owner_id")
self.owner_ids = options.get("owner_ids", set())
self.strip_after_prefix = options.get("strip_after_prefix", False)
self.slash_command_guilds: Optional[Iterable[int]] = options.get("slash_command_guilds", None)
self.description = inspect.cleandoc(description) if description else ''
self.owner_id = options.get('owner_id')
self.owner_ids = options.get('owner_ids', set())
self.strip_after_prefix = options.get('strip_after_prefix', False)
if self.owner_id and self.owner_ids:
raise TypeError("Both owner_id and owner_ids are set.")
raise TypeError('Both owner_id and owner_ids are set.')
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}")
if not (message_commands or slash_commands):
raise ValueError("Both message_commands and slash_commands are disabled.")
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
if help_command is _default:
self.help_command = DefaultHelpCommand()
@ -271,64 +128,13 @@ class BotBase(GroupMixin):
# internal helpers
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None:
# super() will resolve to Client
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = "on_" + event_name
def dispatch(self, event_name, *args, **kwargs):
super().dispatch(event_name, *args, **kwargs)
ev = 'on_' + event_name
for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore
self._schedule_event(event, ev, *args, **kwargs)
async def setup(self):
await self.create_slash_commands()
async def create_slash_commands(self):
commands: defaultdict[Optional[int], List[EditApplicationCommand]] = defaultdict(list)
for command in self.commands:
if command.hidden or (command.slash_command is None and not self.slash_commands):
continue
try:
payload = command.to_application_command()
except Exception:
raise errors.ApplicationCommandRegistrationError(command)
if payload is None:
continue
guilds = command.slash_command_guilds or self.slash_command_guilds
if guilds is None:
commands[None].append(payload)
else:
for guild in guilds:
commands[guild].append(payload)
http: HTTPClient = self.http # type: ignore
global_commands = commands.pop(None, None)
application_id = self.application_id or (await self.application_info()).id # type: ignore
if global_commands is not None:
if self.slash_command_guilds is None:
await http.bulk_upsert_global_commands(
payload=global_commands,
application_id=application_id,
)
else:
for guild in self.slash_command_guilds:
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=global_commands,
application_id=application_id,
)
for guild, guild_commands in commands.items():
assert guild is not None
await http.bulk_upsert_guild_commands(
guild_id=guild,
payload=guild_commands,
application_id=application_id,
)
@discord.utils.copy_doc(discord.Client.close)
async def close(self) -> None:
async def close(self):
for extension in tuple(self.__extensions):
try:
self.unload_extension(extension)
@ -341,9 +147,9 @@ class BotBase(GroupMixin):
except Exception:
pass
await super().close() # type: ignore
await super().close()
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None:
async def on_command_error(self, context, exception):
"""|coro|
The default command error handler provided by the bot.
@ -353,7 +159,7 @@ class BotBase(GroupMixin):
This only fires if you do not specify any listeners for command error.
"""
if self.extra_events.get("on_command_error", None):
if self.extra_events.get('on_command_error', None):
return
command = context.command
@ -364,12 +170,12 @@ class BotBase(GroupMixin):
if cog and cog.has_error_handler():
return
print(f"Ignoring exception in command {context.command}:", file=sys.stderr)
print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
# global check registration
def check(self, func: T) -> T:
def check(self, func):
r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied
@ -394,11 +200,10 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands
"""
# T was used instead of Check to ensure the type matches on return
self.add_check(func) # type: ignore
self.add_check(func)
return func
def add_check(self, func: Check, *, call_once: bool = False) -> None:
def add_check(self, func, *, call_once=False):
"""Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check`
@ -418,7 +223,7 @@ class BotBase(GroupMixin):
else:
self._checks.append(func)
def remove_check(self, func: Check, *, call_once: bool = False) -> None:
def remove_check(self, func, *, call_once=False):
"""Removes a global check from the bot.
This function is idempotent and will not raise an exception
@ -439,7 +244,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def check_once(self, func: CFT) -> CFT:
def check_once(self, func):
r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once
@ -477,16 +282,15 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True)
return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool:
async def can_run(self, ctx, *, call_once=False):
data = self._check_once if call_once else self._checks
if len(data) == 0:
return True
# type-checker doesn't distinguish between functions and methods
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
return await discord.utils.async_all(f(ctx) for f in data)
async def is_owner(self, user: discord.User) -> bool:
async def is_owner(self, user):
"""|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@ -515,61 +319,15 @@ class BotBase(GroupMixin):
elif self.owner_ids:
return user.id in self.owner_ids
else:
# Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime.
await self.populate_owners()
return await self.is_owner(user)
async def try_owners(self) -> List[discord.User]:
"""|coro|
Returns a list of :class:`~discord.User` representing the owners of the bot.
It uses the :attr:`owner_id` and :attr:`owner_ids`, if set.
.. versionadded:: 2.0
The function also checks if the application is team-owned if
:attr:`owner_ids` is not set.
Returns
--------
List[:class:`~discord.User`]
List of owners of the bot.
"""
if self.owner_id:
owner = await self.try_user(self.owner_id)
if owner:
return [owner]
app = await self.application_info()
if app.team:
self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids
else:
return []
self.owner_id = owner_id = app.owner.id
return user.id == owner_id
elif self.owner_ids:
owners = []
for owner_id in self.owner_ids:
owner = await self.try_user(owner_id)
if owner:
owners.append(owner)
return owners
else:
# We didn't have owners cached yet, cache them and retry.
await self.populate_owners()
return await self.try_owners()
async def populate_owners(self):
"""|coro|
Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`.
.. versionadded:: 2.0
"""
app = await self.application_info() # type: ignore
if app.team:
self.owner_ids = {m.id for m in app.team.members}
else:
self.owner_id = app.owner.id
def before_invoke(self, coro: CFT) -> CFT:
def before_invoke(self, coro):
"""A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is
@ -596,12 +354,12 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError("The pre-invoke hook must be a coroutine.")
raise TypeError('The pre-invoke hook must be a coroutine.')
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT:
def after_invoke(self, coro):
r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is
@ -629,21 +387,21 @@ class BotBase(GroupMixin):
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError("The post-invoke hook must be a coroutine.")
raise TypeError('The post-invoke hook must be a coroutine.')
self._after_invoke = coro
return coro
# listener registration
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None:
def add_listener(self, func, name=None):
"""The non decorator alternative to :meth:`.listen`.
Parameters
-----------
func: :ref:`coroutine <coroutine>`
The function to call.
name: :class:`str`
name: Optional[:class:`str`]
The name of the event to listen for. Defaults to ``func.__name__``.
Example
@ -658,17 +416,17 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message')
"""
name = func.__name__ if name is MISSING else name
name = func.__name__ if name is None else name
if not asyncio.iscoroutinefunction(func):
raise TypeError("Listeners must be coroutines")
raise TypeError('Listeners must be coroutines')
if name in self.extra_events:
self.extra_events[name].append(func)
else:
self.extra_events[name] = [func]
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None:
def remove_listener(self, func, name=None):
"""Removes a listener from the pool of listeners.
Parameters
@ -680,7 +438,7 @@ class BotBase(GroupMixin):
``func.__name__``.
"""
name = func.__name__ if name is MISSING else name
name = func.__name__ if name is None else name
if name in self.extra_events:
try:
@ -688,7 +446,7 @@ class BotBase(GroupMixin):
except ValueError:
pass
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]:
def listen(self, name=None):
"""A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready`
@ -718,7 +476,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine.
"""
def decorator(func: CFT) -> CFT:
def decorator(func):
self.add_listener(func, name)
return func
@ -757,20 +515,20 @@ class BotBase(GroupMixin):
"""
if not isinstance(cog, Cog):
raise TypeError("cogs must derive from Cog")
raise TypeError('cogs must derive from Cog')
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f"Cog named {cog_name!r} already loaded")
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name)
cog = cog._inject(self)
self.__cogs[cog_name] = cog
def get_cog(self, name: str) -> Optional[Cog]:
def get_cog(self, name):
"""Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead.
@ -789,8 +547,8 @@ class BotBase(GroupMixin):
"""
return self.__cogs.get(name)
def remove_cog(self, name: str) -> Optional[Cog]:
"""Removes a cog from the bot and returns it.
def remove_cog(self, name):
"""Removes a cog from the bot.
All registered commands and event listeners that the
cog has registered will be removed as well.
@ -801,11 +559,6 @@ class BotBase(GroupMixin):
-----------
name: :class:`str`
The name of the cog to remove.
Returns
-------
Optional[:class:`.Cog`]
The cog that was removed. ``None`` if not found.
"""
cog = self.__cogs.pop(name, None)
@ -817,16 +570,14 @@ class BotBase(GroupMixin):
help_command.cog = None
cog._eject(self)
return cog
@property
def cogs(self) -> Mapping[str, Cog]:
def cogs(self):
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs)
# extensions
def _remove_module_references(self, name: str) -> None:
def _remove_module_references(self, name):
# find all references to the module
# remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items():
@ -850,9 +601,9 @@ class BotBase(GroupMixin):
for index in reversed(remove):
del event_list[index]
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None:
def _call_module_finalizers(self, lib, key):
try:
func = getattr(lib, "teardown")
func = getattr(lib, 'teardown')
except AttributeError:
pass
else:
@ -868,18 +619,18 @@ class BotBase(GroupMixin):
if _is_submodule(name, module):
del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None:
def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib
try:
spec.loader.exec_module(lib) # type: ignore
spec.loader.exec_module(lib)
except Exception as e:
del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e
try:
setup = getattr(lib, "setup")
setup = getattr(lib, 'setup')
except AttributeError:
del sys.modules[key]
raise errors.NoEntryPointError(key)
@ -894,13 +645,13 @@ class BotBase(GroupMixin):
else:
self.__extensions[key] = lib
def _resolve_name(self, name: str, package: Optional[str]) -> str:
def _resolve_name(self, name, package):
try:
return importlib.util.resolve_name(name, package)
except ImportError:
raise errors.ExtensionNotFound(name)
def load_extension(self, name: str, *, package: Optional[str] = None) -> None:
def load_extension(self, name, *, package=None):
"""Loads an extension.
An extension is a python module that contains commands, cogs, or
@ -947,7 +698,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name)
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None:
def unload_extension(self, name, *, package=None):
"""Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are
@ -988,7 +739,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None:
def reload_extension(self, name, *, package=None):
"""Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is
@ -1029,7 +780,11 @@ class BotBase(GroupMixin):
raise errors.ExtensionNotLoaded(name)
# get the previous module states from sys modules
modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)}
modules = {
name: module
for name, module in sys.modules.items()
if _is_submodule(lib.__name__, name)
}
try:
# Unload and then load the module...
@ -1040,7 +795,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been
# cleaned from the load_extension function call
# so let's load it from our old compiled library.
lib.setup(self) # type: ignore
lib.setup(self)
self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller
@ -1048,21 +803,21 @@ class BotBase(GroupMixin):
raise
@property
def extensions(self) -> Mapping[str, types.ModuleType]:
def extensions(self):
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions)
# help command stuff
@property
def help_command(self) -> Optional[HelpCommand]:
def help_command(self):
return self._help_command
@help_command.setter
def help_command(self, value: Optional[HelpCommand]) -> None:
def help_command(self, value):
if value is not None:
if not isinstance(value, HelpCommand):
raise TypeError("help_command must be a subclass of HelpCommand")
raise TypeError('help_command must be a subclass of HelpCommand')
if self._help_command is not None:
self._help_command._remove_from_bot(self)
self._help_command = value
@ -1075,7 +830,7 @@ class BotBase(GroupMixin):
# command processing
async def get_prefix(self, message: Message) -> Union[List[str], str]:
async def get_prefix(self, message):
"""|coro|
Retrieves the prefix the bot is listening to
@ -1092,9 +847,6 @@ class BotBase(GroupMixin):
A list of prefixes or a single prefix that the bot is
listening for.
"""
if isinstance(message, _FakeSlashMessage):
return "/"
prefix = ret = self.command_prefix
if callable(prefix):
ret = await discord.utils.maybe_coroutine(prefix, self, message)
@ -1108,17 +860,15 @@ class BotBase(GroupMixin):
if isinstance(ret, collections.abc.Iterable):
raise
raise TypeError(
"command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}"
)
raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}")
if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix")
return ret
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT:
async def get_context(self, message, *, cls=Context):
r"""|coro|
Returns the invocation context from the message.
@ -1151,7 +901,7 @@ class BotBase(GroupMixin):
view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id: # type: ignore
if message.author.id == self.user.id:
return ctx
prefix = await self.get_prefix(message)
@ -1171,18 +921,14 @@ class BotBase(GroupMixin):
except TypeError:
if not isinstance(prefix, list):
raise TypeError(
"get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}"
)
raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}")
# It's possible a bad command_prefix got us here.
for value in prefix:
if not isinstance(value, str):
raise TypeError(
"Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}"
)
raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}")
# Getting here shouldn't happen
raise
@ -1192,12 +938,11 @@ class BotBase(GroupMixin):
invoker = view.get_word()
ctx.invoked_with = invoker
# type-checker fails to narrow invoked_prefix type.
ctx.prefix = invoked_prefix # type: ignore
ctx.prefix = invoked_prefix
ctx.command = self.all_commands.get(invoker)
return ctx
async def invoke(self, ctx: Context) -> None:
async def invoke(self, ctx):
"""|coro|
Invokes the command given under the invocation context and
@ -1209,21 +954,21 @@ class BotBase(GroupMixin):
The invocation context to invoke.
"""
if ctx.command is not None:
self.dispatch("command", ctx)
self.dispatch('command', ctx)
try:
if await self.can_run(ctx, call_once=True):
await ctx.command.invoke(ctx)
else:
raise errors.CheckFailure("The global check once functions failed.")
raise errors.CheckFailure('The global check once functions failed.')
except errors.CommandError as exc:
await ctx.command.dispatch_error(ctx, exc)
else:
self.dispatch("command_completion", ctx)
self.dispatch('command_completion', ctx)
elif ctx.invoked_with:
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch("command_error", ctx, exc)
self.dispatch('command_error', ctx, exc)
async def process_commands(self, message: Message) -> None:
async def process_commands(self, message):
"""|coro|
This function processes the commands that have been registered
@ -1251,95 +996,9 @@ class BotBase(GroupMixin):
ctx = await self.get_context(message)
await self.invoke(ctx)
async def process_slash_commands(self, interaction: discord.Interaction):
"""|coro|
This function processes a slash command interaction into a usable
message and calls :meth:`.process_commands` based on it. Without this
coroutine slash commands will not be triggered.
By default, this coroutine is called inside the :func:`.on_interaction`
event. If you choose to override the :func:`.on_interaction` event,
then you should invoke this coroutine as well.
.. versionadded:: 2.0
Parameters
-----------
interaction: :class:`discord.Interaction`
The interaction to process slash commands for.
"""
if interaction.type != discord.InteractionType.application_command:
return
interaction.data = cast(ApplicationCommandInteractionData, interaction.data)
command_name, command_options = _unwrap_slash_groups(interaction.data)
command = self.get_command(command_name)
if command is None:
raise errors.CommandNotFound(f'Command "{command_name}" is not found')
# Ensure the interaction channel is usable
channel = interaction.channel
if channel is None or isinstance(channel, discord.PartialMessageable):
if interaction.guild is None:
assert interaction.user is not None
channel = await interaction.user.create_dm()
elif interaction.channel_id is not None:
channel = await interaction.guild.fetch_channel(interaction.channel_id)
else:
return # cannot do anything without stable channel
# Make our fake message so we can pass it to ext.commands
message: discord.Message = _FakeSlashMessage.from_interaction(interaction, channel) # type: ignore
message.content = f"/{command_name}"
# Add arguments to fake message content, in the right order
ignore_params: List[inspect.Parameter] = []
for name, param in command.clean_params.items():
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
for name, flag in param.annotation.get_flags().items():
option = command_options.get(name)
if option is None:
if flag.required:
raise errors.MissingRequiredFlag(flag)
else:
prefix = param.annotation.__commands_flag_prefix__
delimiter = param.annotation.__commands_flag_delimiter__
message.content += f" {prefix}{name}{delimiter}{option['value']}" # type: ignore
continue
option = command_options.get(name)
if option is None:
if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
else:
ignore_params.append(param)
elif (
option["type"] == 3
and not isinstance(param.annotation, Greedy)
and param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}
):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
message.content += f" {_quote_string_safe(option['value'])}"
else:
message.content += f' {option.get("value", "")}'
ctx = await self.get_context(message)
ctx._ignored_params = ignore_params
ctx.interaction = interaction
await self.invoke(ctx)
async def on_message(self, message):
await self.process_commands(message)
async def on_interaction(self, interaction: discord.Interaction):
await self.process_slash_commands(interaction)
class Bot(BotBase, discord.Client):
"""Represents a discord bot.
@ -1382,7 +1041,7 @@ class Bot(BotBase, discord.Client):
when passing an empty string, it should always be last as no prefix
after it will be matched.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``True``. This
Whether the commands should be case insensitive. Defaults to ``False``. This
attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well.
description: :class:`str`
@ -1409,36 +1068,11 @@ class Bot(BotBase, discord.Client):
the ``command_prefix`` is set to ``!``. Defaults to ``False``.
.. versionadded:: 1.7
message_commands: Optional[:class:`bool`]
Whether to process commands based on messages.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``message_command`` parameter
.. versionadded:: 2.0
slash_commands: Optional[:class:`bool`]
Whether to upload and process slash commands.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command`` parameter
.. versionadded:: 2.0
slash_command_guilds: Optional[:class:`List[int]`]
If this is set, only upload slash commands to these guild IDs.
Can be overwritten per command in the command decorators or when making
a :class:`Command` object via the ``slash_command_guilds`` parameter
.. versionadded:: 2.0
"""
pass
class AutoShardedBot(BotBase, discord.AutoShardedClient):
"""This is similar to :class:`.Bot` except that it is inherited from
:class:`discord.AutoShardedClient` instead.
"""
pass

View File

@ -21,31 +21,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
import discord.utils
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
import copy
from ._types import _BaseCommand
if TYPE_CHECKING:
from .bot import BotBase
from .context import Context
from .core import Command
__all__ = (
"CogMeta",
"Cog",
'CogMeta',
'Cog',
)
CogT = TypeVar("CogT", bound="Cog")
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type):
"""A metaclass for defining a cog.
@ -106,24 +91,19 @@ class CogMeta(type):
pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta:
def __new__(cls, *args, **kwargs):
name, bases, attrs = args
attrs["__cog_name__"] = kwargs.pop("name", name)
attrs["__cog_settings__"] = kwargs.pop("command_attrs", {})
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
description = kwargs.pop("description", None)
description = kwargs.pop('description', None)
if description is None:
description = inspect.cleandoc(attrs.get("__doc__", ""))
attrs["__cog_description__"] = description
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
commands = {}
listeners = {}
no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})"
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
@ -138,21 +118,21 @@ class CogMeta(type):
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.")
if elem.startswith(("cog_", "bot_")):
raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.')
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, "__cog_listener__")
getattr(value, '__cog_listener__')
except AttributeError:
continue
else:
if elem.startswith(("cog_", "bot_")):
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
listeners_as_list = []
for listener in listeners.values():
@ -164,19 +144,17 @@ class CogMeta(type):
new_cls.__cog_listeners__ = listeners_as_list
return new_cls
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs):
super().__init__(*args)
@classmethod
def qualified_name(cls) -> str:
def qualified_name(cls):
return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT:
def _cog_special_method(func):
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
@ -188,12 +166,7 @@ class Cog(metaclass=CogMeta):
are equally valid here.
"""
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT:
def __new__(cls, *args, **kwargs):
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
@ -201,10 +174,12 @@ class Cog(metaclass=CogMeta):
cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it.
# r.e type ignore, type-checker complains about overriding a ClassVar
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__}
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
@ -212,15 +187,15 @@ class Cog(metaclass=CogMeta):
parent = command.parent
if parent is not None:
# Get the latest parent reference
parent = lookup[parent.qualified_name] # type: ignore
parent = lookup[parent.qualified_name]
# Update our parent's reference to our self
parent.remove_command(command.name) # type: ignore
parent.add_command(command) # type: ignore
parent.remove_command(command.name)
parent.add_command(command)
return self
def get_commands(self) -> List[Command]:
def get_commands(self):
r"""
Returns
--------
@ -235,20 +210,20 @@ class Cog(metaclass=CogMeta):
return [c for c in self.__cog_commands__ if c.parent is None]
@property
def qualified_name(self) -> str:
def qualified_name(self):
""":class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__
@property
def description(self) -> str:
def description(self):
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__
@description.setter
def description(self, description: str) -> None:
def description(self, description):
self.__cog_description__ = description
def walk_commands(self) -> Generator[Command, None, None]:
def walk_commands(self):
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
@ -257,14 +232,13 @@ class Cog(metaclass=CogMeta):
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
if isinstance(command, GroupMixin):
yield from command.walk_commands()
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]:
def get_listeners(self):
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns
@ -275,12 +249,12 @@ class Cog(metaclass=CogMeta):
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]:
def _get_overridden_method(cls, method):
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, "__cog_special_method__", method)
return getattr(method.__func__, '__cog_special_method__', method)
@classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]:
def listener(cls, name=None):
"""A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`.
@ -298,15 +272,15 @@ class Cog(metaclass=CogMeta):
the name.
"""
if name is not MISSING and not isinstance(name, str):
raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.")
if name is not None and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
def decorator(func: FuncT) -> FuncT:
def decorator(func):
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError("Listener function must be a coroutine function.")
raise TypeError('Listener function must be a coroutine function.')
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
@ -318,18 +292,17 @@ class Cog(metaclass=CogMeta):
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self) -> bool:
def has_error_handler(self):
""":class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, "__cog_special_method__")
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
def cog_unload(self) -> None:
def cog_unload(self):
"""A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular
@ -340,7 +313,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
def bot_check_once(self, ctx: Context) -> bool:
def bot_check_once(self, ctx):
"""A special method that registers as a :meth:`.Bot.check_once`
check.
@ -350,7 +323,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def bot_check(self, ctx: Context) -> bool:
def bot_check(self, ctx):
"""A special method that registers as a :meth:`.Bot.check`
check.
@ -360,7 +333,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
def cog_check(self, ctx: Context) -> bool:
def cog_check(self, ctx):
"""A special method that registers as a :func:`~discord.ext.commands.check`
for every command and subcommand in this cog.
@ -370,7 +343,7 @@ class Cog(metaclass=CogMeta):
return True
@_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None:
async def cog_command_error(self, ctx, error):
"""A special method that is called whenever an error
is dispatched inside this cog.
@ -389,7 +362,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None:
async def cog_before_invoke(self, ctx):
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
@ -404,7 +377,7 @@ class Cog(metaclass=CogMeta):
pass
@_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None:
async def cog_after_invoke(self, ctx):
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.
@ -418,7 +391,7 @@ class Cog(metaclass=CogMeta):
"""
pass
def _inject(self: CogT, bot: BotBase) -> CogT:
def _inject(self, bot):
cls = self.__class__
# realistically, the only thing that can cause loading errors
@ -453,7 +426,7 @@ class Cog(metaclass=CogMeta):
return self
def _eject(self, bot: BotBase) -> None:
def _eject(self, bot):
cls = self.__class__
try:

View File

@ -21,54 +21,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import inspect
import re
from datetime import timedelta
from typing import Any, Dict, Generic, List, Literal, NoReturn, Optional, TYPE_CHECKING, TypeVar, Union, overload
import discord.abc
import discord.utils
import re
from discord.message import Message
from discord import Permissions
__all__ = (
'Context',
)
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord.abc import MessageableChannel
from discord.guild import Guild
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.webhook import WebhookMessage
from discord.interactions import Interaction
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
from .cog import Cog
from .core import Command
from .help import HelpCommand
from .view import StringView
__all__ = ("Context",)
MISSING: Any = discord.utils.MISSING
T = TypeVar("T")
BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar("CogT", bound="Cog")
if TYPE_CHECKING:
P = ParamSpec("P")
else:
P = TypeVar("P")
class Context(discord.abc.Messageable, Generic[BotT]):
class Context(discord.abc.Messageable):
r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about
@ -96,11 +58,11 @@ class Context(discord.abc.Messageable, Generic[BotT]):
This is only of use for within converters.
.. versionadded:: 2.0
prefix: Optional[:class:`str`]
prefix: :class:`str`
The prefix that was used to invoke the command.
command: Optional[:class:`Command`]
command: :class:`Command`
The command that is being invoked currently.
invoked_with: Optional[:class:`str`]
invoked_with: :class:`str`
The command name that triggered this invocation. Useful for finding out
which alias called the command.
invoked_parents: List[:class:`str`]
@ -111,7 +73,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 1.7
invoked_subcommand: Optional[:class:`Command`]
invoked_subcommand: :class:`Command`
The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`]
@ -123,43 +85,24 @@ class Context(discord.abc.Messageable, Generic[BotT]):
A boolean that indicates if the command failed to be parsed, checked,
or invoked.
"""
interaction: Optional[Interaction] = None
def __init__(
self,
*,
message: Message,
bot: BotT,
view: StringView,
args: List[Any] = MISSING,
kwargs: Dict[str, Any] = MISSING,
prefix: Optional[str] = None,
command: Optional[Command] = None,
invoked_with: Optional[str] = None,
invoked_parents: List[str] = MISSING,
invoked_subcommand: Optional[Command] = None,
subcommand_passed: Optional[str] = None,
command_failed: bool = False,
current_parameter: Optional[inspect.Parameter] = None,
):
self.message: Message = message
self.bot: BotT = bot
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._ignored_params: List[inspect.Parameter] = []
self._typing_task: Optional[asyncio.Task[NoReturn]] = None
self._state: ConnectionState = self.message._state
def __init__(self, **attrs):
self.message = attrs.pop('message', None)
self.bot = attrs.pop('bot', None)
self.args = attrs.pop('args', [])
self.kwargs = attrs.pop('kwargs', {})
self.prefix = attrs.pop('prefix')
self.command = attrs.pop('command', None)
self.view = attrs.pop('view', None)
self.invoked_with = attrs.pop('invoked_with', None)
self.invoked_parents = attrs.pop('invoked_parents', [])
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
self.subcommand_passed = attrs.pop('subcommand_passed', None)
self.command_failed = attrs.pop('command_failed', False)
self.current_parameter = attrs.pop('current_parameter', None)
self._state = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T:
async def invoke(self, command, /, *args, **kwargs):
r"""|coro|
Calls a command with the arguments given.
@ -181,7 +124,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
command: :class:`.Command`
The command that is going to be called.
\*args
The arguments to use.
The arguments to to use.
\*\*kwargs
The keyword arguments to use.
@ -190,9 +133,17 @@ class Context(discord.abc.Messageable, Generic[BotT]):
TypeError
The command argument to invoke is missing.
"""
return await command(self, *args, **kwargs)
arguments = []
if command.cog is not None:
arguments.append(command.cog)
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None:
arguments.append(self)
arguments.extend(args)
ret = await command.callback(*arguments, **kwargs)
return ret
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True):
"""|coro|
Calls the command again.
@ -225,7 +176,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
cmd = self.command
view = self.view
if cmd is None:
raise ValueError("This context is not valid.")
raise ValueError('This context is not valid.')
# some state to revert to when we're done
index, previous = view.index, view.previous
@ -236,10 +187,10 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix or "")
view.index = len(self.prefix)
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
@ -255,32 +206,29 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.subcommand_passed = subcommand_passed
@property
def valid(self) -> bool:
def valid(self):
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None
async def _get_channel(self) -> discord.abc.Messageable:
async def _get_channel(self):
return self.channel
@property
def clean_prefix(self) -> str:
def clean_prefix(self):
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0
"""
if self.prefix is None:
return ""
user = self.me
user = self.guild.me if self.guild else self.bot.user
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
@property
def cog(self) -> Optional[Cog]:
def cog(self):
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None:
@ -288,46 +236,38 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return self.command.cog
@discord.utils.cached_property
def guild(self) -> Optional[Guild]:
def guild(self):
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild
@discord.utils.cached_property
def channel(self) -> MessageableChannel:
def channel(self):
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`.
"""
return self.message.channel
@discord.utils.cached_property
def author(self) -> Union[User, Member]:
def author(self):
"""Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
"""
return self.message.author
@discord.utils.cached_property
def me(self) -> Union[Member, ClientUser]:
def me(self):
"""Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
"""
# bot.user will never be None at this point.
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
return self.guild.me if self.guild is not None else self.bot.user
@property
def voice_client(self) -> Optional[VoiceProtocol]:
def voice_client(self):
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
def author_permissions(self) -> Permissions:
"""Returns the author permissions in the given channel.
.. versionadded:: 2.0
"""
return self.channel.permissions_for(self.author)
async def send_help(self, *args: Any) -> Any:
async def send_help(self, *args):
"""send_help(entity=<bot>)
|coro|
@ -379,12 +319,12 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None
entity = args[0]
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
if entity is None:
return None
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
try:
entity.qualified_name
except AttributeError:
@ -394,7 +334,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, "__cog_commands__"):
if hasattr(entity, '__cog_commands__'):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):
@ -408,128 +348,6 @@ class Context(discord.abc.Messageable, Generic[BotT]):
except CommandError as e:
await cmd.on_help_command_error(self, e)
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[False] = False,
ephemeral: bool = False,
**kwargs: Any,
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def send(
self,
content: Optional[str] = None,
return_message: Literal[True] = True,
ephemeral: bool = False,
**kwargs: Any,
) -> Union[Message, WebhookMessage]:
...
async def send(
self, content: Optional[str] = None, return_message: bool = True, ephemeral: bool = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
"""
|coro|
A shortcut method to :meth:`.abc.Messageable.send` with interaction helpers.
This function takes all the parameters of :meth:`.abc.Messageable.send` plus the following:
Parameters
------------
return_message: :class:`bool`
Ignored if not in a slash command context.
If this is set to False more native interaction methods will be used.
ephemeral: :class:`bool`
Ignored if not in a slash command context.
Indicates if the message should only be visible to the user who started the interaction.
If a view is sent with an ephemeral message and it has no timeout set then the timeout
is set to 15 minutes.
Returns
--------
Optional[Union[:class:`.Message`, :class:`.WebhookMessage`]]
In a slash command context, the message that was sent if return_message is True.
In a normal context, it always returns a :class:`.Message`
"""
if self._typing_task is not None:
self._typing_task.cancel()
self._typing_task = None
if self.interaction is None or (
self.interaction.response.responded_at is not None
and discord.utils.utcnow() - self.interaction.response.responded_at >= timedelta(minutes=15)
):
return await super().send(content, **kwargs)
# Remove unsupported arguments from kwargs
kwargs.pop("nonce", None)
kwargs.pop("stickers", None)
kwargs.pop("reference", None)
kwargs.pop("delete_after", None)
kwargs.pop("mention_author", None)
if not (
return_message
or self.interaction.response.is_done()
or any(arg in kwargs for arg in ("file", "files", "allowed_mentions"))
):
send = self.interaction.response.send_message
else:
# We have to defer in order to use the followup webhook
if not self.interaction.response.is_done():
await self.interaction.response.defer(ephemeral=ephemeral)
send = self.interaction.followup.send
return await send(content, ephemeral=ephemeral, **kwargs) # type: ignore
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[False] = False, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
...
@overload
async def reply(
self, content: Optional[str] = None, return_message: Literal[True] = True, **kwargs: Any
) -> Union[Message, WebhookMessage]:
...
@discord.utils.copy_doc(Message.reply)
async def reply(
self, content: Optional[str] = None, return_message: bool = True, **kwargs: Any
) -> Optional[Union[Message, WebhookMessage]]:
return await self.send(content, return_message=return_message, reference=self.message, **kwargs) # type: ignore
async def defer(self, *, ephemeral: bool = False, trigger_typing: bool = True) -> None:
"""|coro|
Defers the Slash Command interaction if ran in a slash command **or**
Loops triggering ``Bot is typing`` in the channel if run in a message command.
Parameters
------------
trigger_typing: :class:`bool`
Indicates whether to trigger typing in a message command.
ephemeral: :class:`bool`
Indicates whether the deferred message will eventually be ephemeral in a slash command.
"""
if self.interaction is None:
if self._typing_task is None and trigger_typing:
async def typing_task():
while True:
await self.trigger_typing()
await asyncio.sleep(10)
self._typing_task = self.bot.loop.create_task(typing_task())
else:
await self.interaction.response.defer(ephemeral=ephemeral)
@discord.utils.copy_doc(discord.Message.reply)
async def reply(self, content=None, **kwargs):
return await self.message.reply(content, **kwargs)

View File

@ -48,37 +48,33 @@ from .errors import *
if TYPE_CHECKING:
from .context import Context
from discord.message import PartialMessageableChannel
__all__ = (
"Converter",
"ObjectConverter",
"MemberConverter",
"UserConverter",
"MessageConverter",
"PartialMessageConverter",
"TextChannelConverter",
"InviteConverter",
"GuildConverter",
"RoleConverter",
"GameConverter",
"ColourConverter",
"ColorConverter",
"VoiceChannelConverter",
"StageChannelConverter",
"EmojiConverter",
"PartialEmojiConverter",
"CategoryChannelConverter",
"IDConverter",
"StoreChannelConverter",
"ThreadConverter",
"GuildChannelConverter",
"GuildStickerConverter",
"clean_content",
"Greedy",
"Option",
"run_converters",
'Converter',
'ObjectConverter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'GuildChannelConverter',
'clean_content',
'Greedy',
'run_converters',
)
@ -92,12 +88,9 @@ def _get_from_guilds(bot, getter, argument):
_utils_get = discord.utils.get
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
CT = TypeVar("CT", bound=discord.abc.GuildChannel)
TT = TypeVar("TT", bound=discord.Thread)
DT = TypeVar("DT", bound=str)
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
@runtime_checkable
@ -135,10 +128,10 @@ class Converter(Protocol[T_co]):
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError("Derived classes need to implement this.")
raise NotImplementedError('Derived classes need to implement this.')
_ID_REGEX = re.compile(r"([0-9]{15,20})$")
_ID_REGEX = re.compile(r'([0-9]{15,20})$')
class IDConverter(Converter[T_co]):
@ -161,7 +154,7 @@ class ObjectConverter(IDConverter[discord.Object]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Object:
match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument)
if match is None:
raise ObjectNotFound(argument)
@ -195,8 +188,8 @@ class MemberConverter(IDConverter[discord.Member]):
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == "#":
username, _, discriminator = argument.rpartition("#")
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
else:
@ -226,7 +219,7 @@ class MemberConverter(IDConverter[discord.Member]):
async def convert(self, ctx: Context, argument: str) -> discord.Member:
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
guild = ctx.guild
result = None
user_id = None
@ -235,13 +228,13 @@ class MemberConverter(IDConverter[discord.Member]):
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, "get_member_named", argument)
result = _get_from_guilds(bot, 'get_member_named', argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
else:
result = _get_from_guilds(bot, "get_member", user_id)
result = _get_from_guilds(bot, 'get_member', user_id)
if result is None:
if guild is None:
@ -279,7 +272,7 @@ class UserConverter(IDConverter[discord.User]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.User:
match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument)
result = None
state = ctx._state
@ -297,12 +290,12 @@ class UserConverter(IDConverter[discord.User]):
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == "@":
if arg[0] == '@':
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == "#":
if len(arg) > 5 and arg[-5] == '#':
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
@ -332,42 +325,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
"""
@staticmethod
def _get_id_matches(ctx, argument):
id_regex = re.compile(r"(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$")
def _get_id_matches(argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
link_regex = re.compile(
r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/"
r"(?P<guild_id>[0-9]{15,20}|@me)"
r"/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$"
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?:[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
data = match.groupdict()
channel_id = discord.utils._get_as_snowflake(data, "channel_id")
message_id = int(data["message_id"])
guild_id = data.get("guild_id")
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == "@me":
guild_id = None
else:
guild_id = int(guild_id)
return guild_id, message_id, channel_id
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
channel_id = match.group('channel_id')
return int(match.group('message_id')), int(channel_id) if channel_id else None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
message_id, channel_id = self._get_id_matches(argument)
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
if not channel:
raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id)
@ -389,11 +362,11 @@ class MessageConverter(IDConverter[discord.Message]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Message:
guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message_id, channel_id = PartialMessageConverter._get_id_matches(argument)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
if not channel:
raise ChannelNotFound(channel_id)
try:
@ -420,20 +393,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel:
return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel)
return self._resolve_channel(ctx, argument, ctx.guild.channels, discord.abc.GuildChannel)
@staticmethod
def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT:
def _resolve_channel(ctx: Context, argument: str, iterable: Iterable[CT], type: Type[CT]) -> CT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
iterable: Iterable[CT] = getattr(guild, attribute)
result: Optional[CT] = discord.utils.get(iterable, name=argument)
else:
@ -446,36 +418,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]):
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, "get_channel", channel_id)
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, type):
raise ChannelNotFound(argument)
return result
@staticmethod
def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT:
bot = ctx.bot
match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
iterable: Iterable[TT] = getattr(guild, attribute)
result: Optional[TT] = discord.utils.get(iterable, name=argument)
else:
thread_id = int(match.group(1))
if guild:
result = guild.get_thread(thread_id)
if not result or not isinstance(result, type):
raise ThreadNotFound(argument)
return result
class TextChannelConverter(IDConverter[discord.TextChannel]):
"""Converts to a :class:`~discord.TextChannel`.
@ -494,7 +443,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.TextChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, ctx.guild.text_channels, discord.TextChannel)
class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
@ -514,7 +463,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, ctx.guild.voice_channels, discord.VoiceChannel)
class StageChannelConverter(IDConverter[discord.StageChannel]):
@ -533,7 +482,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StageChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, ctx.guild.stage_channels, discord.StageChannel)
class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
@ -553,7 +502,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel)
return GuildChannelConverter._resolve_channel(ctx, argument, ctx.guild.categories, discord.CategoryChannel)
class StoreChannelConverter(IDConverter[discord.StoreChannel]):
@ -572,25 +521,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel:
return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel)
class ThreadConverter(IDConverter[discord.Thread]):
"""Coverts to a :class:`~discord.Thread`.
All lookups are via the local guild.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name.
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context, argument: str) -> discord.Thread:
return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread)
return GuildChannelConverter._resolve_channel(ctx, argument, ctx.guild.channels, discord.StoreChannel)
class ColourConverter(Converter[discord.Colour]):
@ -619,10 +550,10 @@ class ColourConverter(Converter[discord.Colour]):
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r"rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)")
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
@ -633,7 +564,7 @@ class ColourConverter(Converter[discord.Colour]):
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == "%":
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
@ -649,29 +580,29 @@ class ColourConverter(Converter[discord.Colour]):
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group("r"))
green = self.parse_rgb_number(argument, match.group("g"))
blue = self.parse_rgb_number(argument, match.group("b"))
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx: Context, argument: str) -> discord.Colour:
if argument[0] == "#":
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
if argument[0:2] == "0x":
if argument[0:2] == '0x':
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith("#"):
if rest.startswith('#'):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == "rgb":
if arg[0:3] == 'rgb':
return self.parse_rgb(arg)
arg = arg.replace(" ", "_")
arg = arg.replace(' ', '_')
method = getattr(discord.Colour, arg, None)
if arg.startswith("from_") or method is None or not inspect.ismethod(method):
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
@ -700,7 +631,7 @@ class RoleConverter(IDConverter[discord.Role]):
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument)
if match:
result = guild.get_role(int(match.group(1)))
else:
@ -732,7 +663,7 @@ class InviteConverter(Converter[discord.Invite]):
invite = await ctx.bot.fetch_invite(argument)
return invite
except Exception as exc:
raise BadInviteArgument(argument) from exc
raise BadInviteArgument() from exc
class GuildConverter(IDConverter[discord.Guild]):
@ -779,7 +710,7 @@ class EmojiConverter(IDConverter[discord.Emoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.Emoji:
match = self._get_id_match(argument) or re.match(r"<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$", argument)
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]{1,32}:([0-9]{15,20})>$', argument)
result = None
bot = ctx.bot
guild = ctx.guild
@ -795,7 +726,11 @@ class EmojiConverter(IDConverter[discord.Emoji]):
emoji_id = int(match.group(1))
# Try to look up emoji by id.
result = bot.get_emoji(emoji_id)
if guild:
result = discord.utils.get(guild.emojis, id=emoji_id)
if result is None:
result = discord.utils.get(bot.emojis, id=emoji_id)
if result is None:
raise EmojiNotFound(argument)
@ -813,7 +748,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
"""
async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji:
match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument)
match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument)
if match:
emoji_animated = bool(match.group(1))
@ -827,45 +762,6 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]):
raise PartialEmojiConversionFailure(argument)
class GuildStickerConverter(IDConverter[discord.GuildSticker]):
"""Converts to a :class:`~discord.GuildSticker`.
All lookups are done for the local guild first, if available. If that lookup
fails, then it checks the client's global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
3. Lookup by name
.. versionadded:: 2.0
"""
async def convert(self, ctx: Context, argument: str) -> discord.GuildSticker:
match = self._get_id_match(argument)
result = None
bot = ctx.bot
guild = ctx.guild
if match is None:
# Try to get the sticker by name. Try local guild first.
if guild:
result = discord.utils.get(guild.stickers, name=argument)
if result is None:
result = discord.utils.get(bot.stickers, name=argument)
else:
sticker_id = int(match.group(1))
# Try to look up sticker by id.
result = bot.get_sticker(sticker_id)
if result is None:
raise GuildStickerNotFound(argument)
return result
class clean_content(Converter[str]):
"""Converts the argument to mention scrubbed version of
said content.
@ -886,66 +782,67 @@ class clean_content(Converter[str]):
.. versionadded:: 1.7
"""
def __init__(
self,
*,
fix_channel_mentions: bool = False,
use_nicknames: bool = True,
escape_markdown: bool = False,
remove_markdown: bool = False,
) -> None:
def __init__(self, *, fix_channel_mentions: bool = False, use_nicknames: bool = True, escape_markdown: bool = False, remove_markdown: bool = False) -> None:
self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx: Context, argument: str) -> str:
msg = ctx.message
if ctx.guild:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id)
return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id)
return f"@{r.name}" if r else "@deleted-role"
else:
def resolve_member(id: int) -> str:
m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id)
return f"@{m.name}" if m else "@deleted-user"
def resolve_role(id: int) -> str:
return "@deleted-role"
message = ctx.message
transformations = {}
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id: int) -> str:
c = ctx.guild.get_channel(id)
return f"#{c.name}" if c else "#deleted-channel"
def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id)
return f'<#{id}>', ('#' + ch.name if ch else '#deleted-channel')
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
if self.use_nicknames and ctx.guild:
def resolve_member(id, *, _get=ctx.guild.get_member):
m = _get(id)
return '@' + m.display_name if m else '@deleted-user'
else:
def resolve_channel(id: int) -> str:
return f"<#{id}>"
def resolve_member(id, *, _get=ctx.bot.get_user):
m = _get(id)
return '@' + m.name if m else '@deleted-user'
transforms = {
"@": resolve_member,
"@!": resolve_member,
"#": resolve_channel,
"@&": resolve_role,
}
# fmt: off
transformations.update(
(f'<@{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions
)
def repl(match: re.Match) -> str:
type = match[1]
id = int(match[2])
transformed = transforms[type](id)
return transformed
transformations.update(
(f'<@!{member_id}>', resolve_member(member_id))
for member_id in message.raw_mentions
)
# fmt: on
if ctx.guild:
def resolve_role(_id, *, _find=ctx.guild.get_role):
r = _find(_id)
return '@' + r.name if r else '@deleted-role'
# fmt: off
transformations.update(
(f'<@&{role_id}>', resolve_role(role_id))
for role_id in message.raw_role_mentions
)
# fmt: on
def repl(obj):
return transformations.get(obj.group(0), '')
pattern = re.compile('|'.join(transformations.keys()))
result = pattern.sub(repl, argument)
result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
@ -977,89 +874,42 @@ class Greedy(List[T]):
For more information, check :ref:`ext_commands_special_converters`.
"""
__slots__ = ("converter",)
__slots__ = ('converter',)
def __init__(self, *, converter: T):
self.converter = converter
def __repr__(self):
converter = getattr(self.converter, "__name__", repr(self.converter))
return f"Greedy[{converter}]"
converter = getattr(self.converter, '__name__', repr(self.converter))
return f'Greedy[{converter}]'
def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]:
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError("Greedy[...] only takes a single argument")
raise TypeError('Greedy[...] only takes a single argument')
converter = params[0]
origin = getattr(converter, "__origin__", None)
args = getattr(converter, "__args__", ())
origin = getattr(converter, '__origin__', None)
args = getattr(converter, '__args__', ())
if not (callable(converter) or isinstance(converter, Converter) or origin is not None):
raise TypeError("Greedy[...] expects a type or a Converter instance.")
raise TypeError('Greedy[...] expects a type or a Converter instance.')
if converter in (str, type(None)) or origin is Greedy:
raise TypeError(f"Greedy[{converter.__name__}] is invalid.")
raise TypeError(f'Greedy[{converter.__name__}] is invalid.')
if origin is Union and type(None) in args:
raise TypeError(f"Greedy[{converter!r}] is invalid.")
raise TypeError(f'Greedy[{converter!r}] is invalid.')
return cls(converter=converter)
class Option(Generic[T, DT]): # type: ignore
"""A special 'converter' to apply a description to slash command options.
For example in the following code:
.. code-block:: python3
@bot.command()
async def ban(ctx,
member: discord.Member, *,
reason: str = commands.Option('no reason', description='the reason to ban this member')
):
await member.ban(reason=reason)
The description would be ``the reason to ban this member`` and the default would be ``no reason``
.. versionadded:: 2.0
Attributes
------------
default: Optional[Any]
The default for this option, overwrites Option during parsing.
description: :class:`str`
The description for this option, is unpacked to :attr:`.Command.option_descriptions`
name: :class:`str`
The name of the option. This defaults to the parameter name.
"""
description: DT
default: Union[T, inspect._empty]
__slots__ = (
"default",
"description",
"name",
)
def __init__(
self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING
) -> None:
self.description = description
self.default = default
self.name: str = name
Option: Any
def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower()
if lowered in ("yes", "y", "true", "t", "1", "enable", "on"):
if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
return True
elif lowered in ("no", "n", "false", "f", "0", "disable", "off"):
elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'):
return False
else:
raise BadBoolArgument(lowered)
@ -1100,9 +950,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = {
discord.PartialEmoji: PartialEmojiConverter,
discord.CategoryChannel: CategoryChannelConverter,
discord.StoreChannel: StoreChannelConverter,
discord.Thread: ThreadConverter,
discord.abc.GuildChannel: GuildChannelConverter,
discord.GuildSticker: GuildStickerConverter,
}
@ -1115,7 +963,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp
except AttributeError:
pass
else:
if module is not None and (module.startswith("discord.") and not module.endswith("converter")):
if module is not None and (module.startswith('discord.') and not module.endswith('converter')):
converter = CONVERTER_MAPPING.get(converter, converter)
try:
@ -1174,7 +1022,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect.
Any
The resulting conversion.
"""
origin = getattr(converter, "__origin__", None)
origin = getattr(converter, '__origin__', None)
if origin is Union:
errors = []

View File

@ -22,10 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
from discord.enums import Enum
import time
import asyncio
@ -34,31 +30,24 @@ from collections import deque
from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from ...message import Message
__all__ = (
"BucketType",
"Cooldown",
"CooldownMapping",
"DynamicCooldownMapping",
"MaxConcurrency",
'BucketType',
'Cooldown',
'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency',
)
C = TypeVar("C", bound="CooldownMapping")
MC = TypeVar("MC", bound="MaxConcurrency")
class BucketType(Enum):
default = 0
user = 1
guild = 2
channel = 3
member = 4
default = 0
user = 1
guild = 2
channel = 3
member = 4
category = 5
role = 6
role = 6
def get_key(self, msg: Message) -> Any:
def get_key(self, msg):
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
@ -68,52 +57,29 @@ class BucketType(Enum):
elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category:
return (msg.channel.category or msg.channel).id # type: ignore
return (msg.channel.category or msg.channel).id
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
def __call__(self, msg: Message) -> Any:
def __call__(self, msg):
return self.get_key(msg)
class Cooldown:
"""Represents a cooldown for a command.
__slots__ = ('rate', 'per', '_window', '_tokens', '_last')
Attributes
-----------
rate: :class:`int`
The total number of tokens available per :attr:`per` seconds.
per: :class:`float`
The length of the cooldown period in seconds.
"""
def __init__(self, rate, per):
self.rate = int(rate)
self.per = float(per)
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
__slots__ = ("rate", "per", "_window", "_tokens", "_last")
def __init__(self, rate: float, per: float) -> None:
self.rate: int = int(rate)
self.per: float = float(per)
self._window: float = 0.0
self._tokens: int = self.rate
self._last: float = 0.0
def get_tokens(self, current: Optional[float] = None) -> int:
"""Returns the number of available tokens before rate limiting is applied.
Parameters
------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to calculate tokens at.
If not supplied then :func:`time.time()` is used.
Returns
--------
:class:`int`
The number of tokens available before the cooldown is to be applied.
"""
def get_tokens(self, current=None):
if not current:
current = time.time()
@ -123,20 +89,7 @@ class Cooldown:
tokens = self.rate
return tokens
def get_retry_after(self, current: Optional[float] = None) -> float:
"""Returns the time in seconds until the cooldown will be reset.
Parameters
-------------
current: Optional[:class:`float`]
The current time in seconds since Unix epoch.
If not supplied, then :func:`time.time()` is used.
Returns
-------
:class:`float`
The number of seconds to wait before this cooldown will be reset.
"""
def get_retry_after(self, current=None):
current = current or time.time()
tokens = self.get_tokens(current)
@ -145,20 +98,7 @@ class Cooldown:
return 0.0
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]:
"""Updates the cooldown rate limit.
Parameters
-------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to update the rate limit at.
If not supplied, then :func:`time.time()` is used.
Returns
-------
Optional[:class:`float`]
The retry-after time in seconds if rate limited.
"""
def update_rate_limit(self, current=None):
current = current or time.time()
self._last = current
@ -175,59 +115,47 @@ class Cooldown:
# we're not so decrement our tokens
self._tokens -= 1
def reset(self) -> None:
"""Reset the cooldown to its initial state."""
# see if we got rate limited due to this token change, and if
# so update the window to point to our current time frame
if self._tokens == 0:
self._window = current
def reset(self):
self._tokens = self.rate
self._last = 0.0
def copy(self) -> Cooldown:
"""Creates a copy of this cooldown.
Returns
--------
:class:`Cooldown`
A new instance of this cooldown.
"""
def copy(self):
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f"<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>"
def __repr__(self):
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping:
def __init__(
self,
original: Optional[Cooldown],
type: Callable[[Message], Any],
) -> None:
def __init__(self, original, type):
if not callable(type):
raise TypeError("Cooldown type must be a BucketType or callable")
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache: Dict[Any, Cooldown] = {}
self._cooldown: Optional[Cooldown] = original
self._type: Callable[[Message], Any] = type
self._cache = {}
self._cooldown = original
self._type = type
def copy(self) -> CooldownMapping:
def copy(self):
ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self) -> bool:
def valid(self):
return self._cooldown is not None
@property
def type(self) -> Callable[[Message], Any]:
return self._type
@classmethod
def from_cooldown(cls: Type[C], rate, per, type) -> C:
def from_cooldown(cls, rate, per, type):
return cls(Cooldown(rate, per), type)
def _bucket_key(self, msg: Message) -> Any:
def _bucket_key(self, msg):
return self._type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None:
def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted
@ -236,12 +164,12 @@ class CooldownMapping:
for k in dead_keys:
del self._cache[k]
def create_bucket(self, message: Message) -> Cooldown:
return self._cooldown.copy() # type: ignore
def create_bucket(self, message):
return self._cooldown.copy()
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
def get_bucket(self, message, current=None):
if self._type is BucketType.default:
return self._cooldown # type: ignore
return self._cooldown
self._verify_cache_integrity(current)
key = self._bucket_key(message)
@ -254,29 +182,28 @@ class CooldownMapping:
return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]:
def update_rate_limit(self, message, current=None):
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
def copy(self) -> DynamicCooldownMapping:
def __init__(self, factory, type):
super().__init__(None, type)
self._factory = factory
def copy(self):
ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self) -> bool:
def valid(self):
return True
def create_bucket(self, message: Message) -> Cooldown:
def create_bucket(self, message):
return self._factory(message)
class _Semaphore:
"""This class is a version of a semaphore.
@ -290,30 +217,30 @@ class _Semaphore:
overkill for what is basically a counter.
"""
__slots__ = ("value", "loop", "_waiters")
__slots__ = ('value', 'loop', '_waiters')
def __init__(self, number: int) -> None:
self.value: int = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self._waiters: Deque[asyncio.Future] = deque()
def __init__(self, number):
self.value = number
self.loop = asyncio.get_event_loop()
self._waiters = deque()
def __repr__(self) -> str:
return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>"
def __repr__(self):
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>'
def locked(self) -> bool:
def locked(self):
return self.value == 0
def is_active(self) -> bool:
def is_active(self):
return len(self._waiters) > 0
def wake_up(self) -> None:
def wake_up(self):
while self._waiters:
future = self._waiters.popleft()
if not future.done():
future.set_result(None)
return
async def acquire(self, *, wait: bool = False) -> bool:
async def acquire(self, *, wait=False):
if not wait and self.value <= 0:
# signal that we're not acquiring
return False
@ -332,36 +259,35 @@ class _Semaphore:
self.value -= 1
return True
def release(self) -> None:
def release(self):
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ("number", "per", "wait", "_mapping")
__slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None:
self._mapping: Dict[Any, _Semaphore] = {}
self.per: BucketType = per
self.number: int = number
self.wait: bool = wait
def __init__(self, number, *, per, wait):
self._mapping = {}
self.per = per
self.number = number
self.wait = wait
if number <= 0:
raise ValueError("max_concurrency 'number' cannot be less than 1")
raise ValueError('max_concurrency \'number\' cannot be less than 1')
if not isinstance(per, BucketType):
raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}")
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
def copy(self: MC) -> MC:
def copy(self):
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str:
return f"<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>"
def __repr__(self):
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>'
def get_key(self, message: Message) -> Any:
def get_key(self, message):
return self.per.get_key(message)
async def acquire(self, message: Message) -> None:
async def acquire(self, message):
key = self.get_key(message)
try:
@ -373,7 +299,7 @@ class MaxConcurrency:
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message: Message) -> None:
async def release(self, message):
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -59,9 +59,9 @@ import sys
import re
__all__ = (
"Flag",
"flag",
"FlagConverter",
'Flag',
'flag',
'FlagConverter',
)
@ -81,8 +81,6 @@ class Flag:
------------
name: :class:`str`
The name of the flag.
description: :class:`str`
The description of the flag.
aliases: List[:class:`str`]
The aliases of the flag name.
attribute: :class:`str`
@ -99,7 +97,6 @@ class Flag:
"""
name: str = MISSING
description: str = MISSING
aliases: List[str] = field(default_factory=list)
attribute: str = MISSING
annotation: Any = MISSING
@ -120,7 +117,6 @@ class Flag:
def flag(
*,
name: str = MISSING,
description: str = MISSING,
aliases: List[str] = MISSING,
default: Any = MISSING,
max_args: int = MISSING,
@ -133,8 +129,6 @@ def flag(
------------
name: :class:`str`
The flag name. If not given, defaults to the attribute name.
description: :class:`str`
Description of the flag for the slash commands options. The default value is `'no description'`.
aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set.
default: Any
@ -149,27 +143,25 @@ def flag(
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(
name=name, description=description, aliases=aliases, default=default, max_args=max_args, override=override
)
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError("flag names should not be empty")
raise ValueError('flag names should not be empty')
for ch in name:
if ch.isspace():
raise ValueError(f"flag name {name!r} cannot have spaces")
if ch == "\\":
raise ValueError(f"flag name {name!r} cannot have backslashes")
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
if ch in forbidden:
raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them")
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get("__annotations__", {})
case_insensitive = namespace["__commands_flag_case_insensitive__"]
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
@ -186,11 +178,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if (
flag.default is MISSING
and hasattr(annotation, "__commands_is_flag__")
and annotation._can_be_constructible()
):
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
flag.default = annotation._construct_default
if flag.aliases is MISSING:
@ -241,7 +229,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag")
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
if flag.override is MISSING:
flag.override = False
@ -249,7 +237,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.")
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
else:
names.add(name)
@ -257,7 +245,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.")
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
else:
names.add(alias)
@ -286,10 +274,10 @@ class FlagsMeta(type):
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs["__commands_is_flag__"] = True
attrs['__commands_is_flag__'] = True
try:
global_ns = sys.modules[attrs["__module__"]].__dict__
global_ns = sys.modules[attrs['__module__']].__dict__
except KeyError:
global_ns = {}
@ -308,26 +296,26 @@ class FlagsMeta(type):
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get("__commands_is_flag__", False):
flags.update(base.__dict__["__commands_flags__"])
aliases.update(base.__dict__["__commands_flag_aliases__"])
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if case_insensitive is MISSING:
attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"]
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
if delimiter is MISSING:
attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"]
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
if prefix is MISSING:
attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"]
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
if case_insensitive is not MISSING:
attrs["__commands_flag_case_insensitive__"] = case_insensitive
attrs['__commands_flag_case_insensitive__'] = case_insensitive
if delimiter is not MISSING:
attrs["__commands_flag_delimiter__"] = delimiter
attrs['__commands_flag_delimiter__'] = delimiter
if prefix is not MISSING:
attrs["__commands_flag_prefix__"] = prefix
attrs['__commands_flag_prefix__'] = prefix
case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False)
delimiter = attrs.setdefault("__commands_flag_delimiter__", ":")
prefix = attrs.setdefault("__commands_flag_prefix__", "")
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
@ -349,11 +337,11 @@ class FlagsMeta(type):
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = "|".join(keys)
pattern = re.compile(f"(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})", regex_flags)
attrs["__commands_flag_regex__"] = pattern
attrs["__commands_flags__"] = flags
attrs["__commands_flag_aliases__"] = aliases
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
return type.__new__(cls, name, bases, attrs)
@ -444,7 +432,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -
raise BadFlagArgument(flag) from e
F = TypeVar("F", bound="FlagConverter")
F = TypeVar('F', bound='FlagConverter')
class FlagConverter(metaclass=FlagsMeta):
@ -505,8 +493,8 @@ class FlagConverter(metaclass=FlagsMeta):
return self
def __repr__(self) -> str:
pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()])
return f"<{self.__class__.__name__} {pairs}>"
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
@ -519,7 +507,7 @@ class FlagConverter(metaclass=FlagsMeta):
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group("flag")
key = match.group('flag')
if case_insensitive:
key = key.casefold()

View File

@ -27,22 +27,16 @@ import copy
import functools
import inspect
import re
from typing import Optional, TYPE_CHECKING
import discord.utils
from .core import Group, Command
from .errors import CommandError
if TYPE_CHECKING:
from .context import Context
__all__ = (
"Paginator",
"HelpCommand",
"DefaultHelpCommand",
"MinimalHelpCommand",
'Paginator',
'HelpCommand',
'DefaultHelpCommand',
'MinimalHelpCommand',
)
# help -> shows info of bot on top/bottom and lists subcommands
@ -89,7 +83,7 @@ class Paginator:
.. versionadded:: 1.7
"""
def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"):
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
self.prefix = prefix
self.suffix = suffix
self.max_size = max_size
@ -118,7 +112,7 @@ class Paginator:
def _linesep_len(self):
return len(self.linesep)
def add_line(self, line="", *, empty=False):
def add_line(self, line='', *, empty=False):
"""Adds a line to the current page.
If the line exceeds the :attr:`max_size` then an exception
@ -138,7 +132,7 @@ class Paginator:
"""
max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len
if len(line) > max_page_size:
raise RuntimeError(f"Line exceeds maximum page size {max_page_size}")
raise RuntimeError(f'Line exceeds maximum page size {max_page_size}')
if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len:
self.close_page()
@ -147,7 +141,7 @@ class Paginator:
self._current_page.append(line)
if empty:
self._current_page.append("")
self._current_page.append('')
self._count += self._linesep_len
def close_page(self):
@ -176,7 +170,7 @@ class Paginator:
return self._pages
def __repr__(self):
fmt = "<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>"
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
return fmt.format(self)
@ -197,7 +191,7 @@ class _HelpCommandImpl(Command):
self.callback = injected.command_callback
on_error = injected.on_help_command_error
if not hasattr(on_error, "__help_command_not_overriden__"):
if not hasattr(on_error, '__help_command_not_overriden__'):
if self.cog is not None:
self.on_error = self._on_error_cog_implementation
else:
@ -224,7 +218,7 @@ class _HelpCommandImpl(Command):
try:
del result[next(iter(result))]
except StopIteration:
raise ValueError("Missing context parameter") from None
raise ValueError('Missing context parameter') from None
else:
return result
@ -296,13 +290,13 @@ class HelpCommand:
"""
MENTION_TRANSFORMS = {
"@everyone": "@\u200beveryone",
"@here": "@\u200bhere",
r"<@!?[0-9]{17,22}>": "@deleted-user",
r"<@&[0-9]{17,22}>": "@deleted-role",
'@everyone': '@\u200beveryone',
'@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role',
}
MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys()))
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
def __new__(cls, *args, **kwargs):
# To prevent race conditions of a single instance while also allowing
@ -321,12 +315,12 @@ class HelpCommand:
return self
def __init__(self, **options):
self.show_hidden = options.pop("show_hidden", False)
self.verify_checks = options.pop("verify_checks", True)
self.command_attrs = attrs = options.pop("command_attrs", {})
attrs.setdefault("name", "help")
attrs.setdefault("help", "Shows this message")
self.context: Context = discord.utils.MISSING
self.show_hidden = options.pop('show_hidden', False)
self.verify_checks = options.pop('verify_checks', True)
self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message')
self.context = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self):
@ -422,20 +416,20 @@ class HelpCommand:
if not parent.signature or parent.invoke_without_command:
entries.append(parent.name)
else:
entries.append(parent.name + " " + parent.signature)
entries.append(parent.name + ' ' + parent.signature)
parent = parent.parent
parent_sig = " ".join(reversed(entries))
parent_sig = ' '.join(reversed(entries))
if len(command.aliases) > 0:
aliases = "|".join(command.aliases)
fmt = f"[{command.name}|{aliases}]"
aliases = '|'.join(command.aliases)
fmt = f'[{command.name}|{aliases}]'
if parent_sig:
fmt = parent_sig + " " + fmt
fmt = parent_sig + ' ' + fmt
alias = fmt
else:
alias = command.name if not parent_sig else parent_sig + " " + command.name
alias = command.name if not parent_sig else parent_sig + ' ' + command.name
return f"{self.context.clean_prefix}{alias} {command.signature}"
return f'{self.context.clean_prefix}{alias} {command.signature}'
def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse.
@ -449,7 +443,7 @@ class HelpCommand:
"""
def replace(obj, *, transforms=self.MENTION_TRANSFORMS):
return transforms.get(obj.group(0), "@invalid")
return transforms.get(obj.group(0), '@invalid')
return self.MENTION_PATTERN.sub(replace, string)
@ -615,7 +609,7 @@ class HelpCommand:
:class:`.abc.Messageable`
The destination where the help command will be output.
"""
return self.context
return self.context.channel
async def send_error_message(self, error):
"""|coro|
@ -846,7 +840,7 @@ class HelpCommand:
# Since we want to have detailed errors when someone
# passes an invalid subcommand, we need to walk through
# the command group chain ourselves.
keys = command.split(" ")
keys = command.split(' ')
cmd = bot.all_commands.get(keys[0])
if cmd is None:
string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0]))
@ -907,14 +901,14 @@ class DefaultHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.width = options.pop("width", 80)
self.indent = options.pop("indent", 2)
self.sort_commands = options.pop("sort_commands", True)
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.commands_heading = options.pop("commands_heading", "Commands:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
self.width = options.pop('width', 80)
self.indent = options.pop('indent', 2)
self.sort_commands = options.pop('sort_commands', True)
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.commands_heading = options.pop('commands_heading', "Commands:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator()
@ -924,7 +918,7 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width:
return text[: self.width - 3].rstrip() + "..."
return text[:self.width - 3].rstrip() + '...'
return text
def get_ending_note(self):
@ -977,10 +971,6 @@ class DefaultHelpCommand(HelpCommand):
for page in self.paginator.pages:
await destination.send(page)
interaction = self.context.interaction
if interaction is not None and destination == self.context.author and not interaction.response.is_done():
await interaction.response.send_message("Sent help to your DMs!", ephemeral=True)
def add_command_formatting(self, command):
"""A utility function to format the non-indented block of commands and groups.
@ -1011,7 +1001,7 @@ class DefaultHelpCommand(HelpCommand):
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
return ctx.author
else:
return ctx
return ctx.channel
async def prepare_help_command(self, ctx, command):
self.paginator.clear()
@ -1025,11 +1015,11 @@ class DefaultHelpCommand(HelpCommand):
# <description> portion
self.paginator.add_line(bot.description, empty=True)
no_category = f"\u200b{self.no_category}:"
no_category = f'\u200b{self.no_category}:'
def get_category(command, *, no_category=no_category):
cog = command.cog
return cog.qualified_name + ":" if cog is not None else no_category
return cog.qualified_name + ':' if cog is not None else no_category
filtered = await self.filter_commands(bot.commands, sort=True, key=get_category)
max_size = self.get_max_size(filtered)
@ -1114,13 +1104,13 @@ class MinimalHelpCommand(HelpCommand):
"""
def __init__(self, **options):
self.sort_commands = options.pop("sort_commands", True)
self.commands_heading = options.pop("commands_heading", "Commands")
self.dm_help = options.pop("dm_help", False)
self.dm_help_threshold = options.pop("dm_help_threshold", 1000)
self.aliases_heading = options.pop("aliases_heading", "Aliases:")
self.no_category = options.pop("no_category", "No Category")
self.paginator = options.pop("paginator", None)
self.sort_commands = options.pop('sort_commands', True)
self.commands_heading = options.pop('commands_heading', "Commands")
self.dm_help = options.pop('dm_help', False)
self.dm_help_threshold = options.pop('dm_help_threshold', 1000)
self.aliases_heading = options.pop('aliases_heading', "Aliases:")
self.no_category = options.pop('no_category', 'No Category')
self.paginator = options.pop('paginator', None)
if self.paginator is None:
self.paginator = Paginator(suffix=None, prefix=None)
@ -1153,7 +1143,7 @@ class MinimalHelpCommand(HelpCommand):
)
def get_command_signature(self, command):
return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}"
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}'
def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes.
@ -1184,8 +1174,8 @@ class MinimalHelpCommand(HelpCommand):
"""
if commands:
# U+2002 Middle Dot
joined = "\u2002".join(c.name for c in commands)
self.paginator.add_line(f"__**{heading}**__")
joined = '\u2002'.join(c.name for c in commands)
self.paginator.add_line(f'__**{heading}**__')
self.paginator.add_line(joined)
def add_subcommand_formatting(self, command):
@ -1201,7 +1191,7 @@ class MinimalHelpCommand(HelpCommand):
command: :class:`Command`
The command to show information of.
"""
fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}"
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases):
@ -1272,7 +1262,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
no_category = f"\u200b{self.no_category}"
no_category = f'\u200b{self.no_category}'
def get_category(command, *, no_category=no_category):
cog = command.cog
@ -1306,7 +1296,7 @@ class MinimalHelpCommand(HelpCommand):
filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands)
if filtered:
self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**")
self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**')
for command in filtered:
self.add_subcommand_formatting(command)
@ -1326,7 +1316,7 @@ class MinimalHelpCommand(HelpCommand):
if note:
self.paginator.add_line(note, empty=True)
self.paginator.add_line(f"**{self.commands_heading}**")
self.paginator.add_line(f'**{self.commands_heading}**')
for command in filtered:
self.add_subcommand_formatting(command)

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
supported_quotes = {
_quotes = {
'"': '"',
"": "",
"": "",
@ -44,8 +44,7 @@ supported_quotes = {
"": "",
"": "",
}
_all_quotes = set(supported_quotes.keys()) | set(supported_quotes.values())
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
@ -82,20 +81,20 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index : self.index + strlen] == string:
if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
result = self.buffer[self.index :]
result = self.buffer[self.index:]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
result = self.buffer[self.index : self.index + n]
result = self.buffer[self.index:self.index + n]
self.previous = self.index
self.index += n
return result
@ -121,7 +120,7 @@ class StringView:
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index : self.index + pos]
result = self.buffer[self.index:self.index + pos]
self.index += pos
return result
@ -130,7 +129,7 @@ class StringView:
if current is None:
return None
close_quote = supported_quotes.get(current)
close_quote = _quotes.get(current)
is_quoted = bool(close_quote)
if is_quoted:
result = []
@ -145,11 +144,11 @@ class StringView:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return "".join(result)
return ''.join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == "\\":
if current == '\\':
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
@ -157,7 +156,7 @@ class StringView:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return "".join(result)
return ''.join(result)
if next_char in _escaped_quotes:
# escaped quote
@ -180,13 +179,14 @@ class StringView:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return "".join(result)
return ''.join(result)
if current.isspace() and not is_quoted:
# end of word found
return "".join(result)
return ''.join(result)
result.append(current)
def __repr__(self):
return f"<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>"
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>'

View File

@ -27,20 +27,24 @@ from __future__ import annotations
import asyncio
import datetime
from typing import (
Any,
Awaitable,
Any,
Awaitable,
Callable,
Coroutine,
Generic,
List,
List,
Optional,
Type,
TYPE_CHECKING,
Type,
TypeVar,
Union,
cast,
)
import aiohttp
import discord
import inspect
import logging
import sys
import traceback
@ -48,17 +52,30 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
__all__ = ("loop",)
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
T = TypeVar("T")
P = ParamSpec("P")
else:
P = TypeVar("P") # hacky runtime fix
log = logging.getLogger(__name__)
__all__ = (
'loop',
)
C = TypeVar('C')
T = TypeVar('T')
_coro = Coroutine[Any, Any, T]
_func = Callable[..., Awaitable[Any]]
LF = TypeVar("LF", bound=_func)
FT = TypeVar("FT", bound=_func)
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
class SleepHandle:
__slots__ = ("future", "loop", "handle")
__slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
@ -71,7 +88,7 @@ class SleepHandle:
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]:
def wait(self) -> asyncio.Future:
return self.future
def done(self) -> bool:
@ -82,31 +99,29 @@ class SleepHandle:
self.future.cancel()
class Loop(Generic[LF]):
class Loop(Generic[C, P, T]):
"""A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`.
"""
def __init__(
self,
coro: LF,
def __init__(self,
coro: Callable[P, _coro[T]],
seconds: float,
hours: float,
minutes: float,
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: asyncio.AbstractEventLoop,
loop: Optional[asyncio.AbstractEventLoop],
) -> None:
self.coro: LF = coro
self.coro: Callable[P, _coro[T]] = coro
self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle: SleepHandle = MISSING
self._task: asyncio.Task[None] = MISSING
self._injected = None
self._handle = None
self._task = None
self._injected: Optional[C] = None
self._valid_exception = (
OSError,
discord.GatewayNotFound,
@ -122,18 +137,18 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError("count must be greater than 0 or None.")
raise ValueError('count must be greater than 0 or None.')
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
self._last_iteration: datetime.datetime = MISSING
self._last_iteration = None
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.")
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
coro = getattr(self, "_" + name)
coro = getattr(self, '_' + name)
if coro is None:
return
@ -142,13 +157,14 @@ class Loop(Generic[LF]):
else:
await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop)
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff()
await self._call_loop_function("before_loop")
await self._call_loop_function('before_loop')
self._last_iteration_failed = False
if self._time is not MISSING:
# the time index should be prepared every time the internal loop is started
@ -172,7 +188,7 @@ class Loop(Generic[LF]):
await asyncio.sleep(backoff.delay())
else:
await self._try_sleep_until(self._next_iteration)
if self._stop_next_iteration:
return
@ -191,28 +207,28 @@ class Loop(Generic[LF]):
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function("error", exc)
await self._call_loop_function('error', exc)
raise exc
finally:
await self._call_loop_function("after_loop")
await self._call_loop_function('after_loop')
self._handle.cancel()
self._is_being_cancelled = False
self._current_loop = 0
self._stop_next_iteration = False
self._has_failed = False
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
def __get__(self, obj: C, objtype: Type[C]) -> Loop[C, P, T]:
if obj is None:
return self
copy: Loop[LF] = Loop(
self.coro,
seconds=self._seconds,
hours=self._hours,
copy = Loop[C, P, T](
self.coro,
seconds=self._seconds,
hours=self._hours,
minutes=self._minutes,
time=self._time,
time=self._time,
count=self.count,
reconnect=self.reconnect,
reconnect=self.reconnect,
loop=self.loop,
)
copy._injected = obj
@ -231,7 +247,7 @@ class Loop(Generic[LF]):
"""
if self._seconds is not MISSING:
return self._seconds
@property
def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes
@ -241,7 +257,7 @@ class Loop(Generic[LF]):
"""
if self._minutes is not MISSING:
return self._minutes
@property
def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours
@ -273,13 +289,13 @@ class Loop(Generic[LF]):
.. versionadded:: 1.3
"""
if self._task is MISSING:
if self._task is None:
return None
elif self._task and self._task.done() or self._stop_next_iteration:
return None
return self._next_iteration
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro|
Calls the internal callback that the task holds.
@ -299,7 +315,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
def start(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task:
r"""Starts the internal task in the event loop.
Parameters
@ -320,13 +336,13 @@ class Loop(Generic[LF]):
The task that has been created.
"""
if self._task is not MISSING and not self._task.done():
raise RuntimeError("Task is already launched and is not completed.")
if self._task is not None and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None:
args = (self._injected, *args)
if self.loop is MISSING:
if self.loop is None:
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
@ -350,7 +366,7 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2
"""
if self._task is not MISSING and not self._task.done():
if self._task and not self._task.done():
self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool:
@ -361,7 +377,7 @@ class Loop(Generic[LF]):
if self._can_be_cancelled():
self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None:
def restart(self, *args: P.args, **kwargs: P.kwargs) -> None:
r"""A convenience method to restart the internal task.
.. note::
@ -372,12 +388,12 @@ class Loop(Generic[LF]):
Parameters
------------
\*args
The arguments to use.
The arguments to to use.
\*\*kwargs
The keyword arguments to use.
"""
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
def restart_when_over(fut, *, args=args, kwargs=kwargs):
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
@ -408,9 +424,9 @@ class Loop(Generic[LF]):
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError(f"{exc!r} must be a class.")
raise TypeError(f'{exc!r} must be a class.')
if not issubclass(exc, BaseException):
raise TypeError(f"{exc!r} must inherit from BaseException.")
raise TypeError(f'{exc!r} must inherit from BaseException.')
self._valid_exception = (*self._valid_exception, *exceptions)
@ -440,9 +456,9 @@ class Loop(Generic[LF]):
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task[None]]:
def get_task(self) -> Optional[asyncio.Task]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task if self._task is not MISSING else None
return self._task
def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled."""
@ -460,11 +476,11 @@ class Loop(Generic[LF]):
.. versionadded:: 1.4
"""
return not bool(self._task.done()) if self._task is not MISSING else False
return not bool(self._task.done()) if self._task else False
async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr)
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro: FT) -> FT:
@ -487,7 +503,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
@ -515,7 +531,7 @@ class Loop(Generic[LF]):
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
@ -541,7 +557,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.")
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
@ -554,9 +570,7 @@ class Loop(Generic[LF]):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
next_time = self._time[self._time_index]
@ -564,7 +578,7 @@ class Loop(Generic[LF]):
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = self._last_iteration
next_date = cast(datetime.datetime, self._last_iteration)
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)
@ -572,14 +586,12 @@ class Loop(Generic[LF]):
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
@ -595,24 +607,20 @@ class Loop(Generic[LF]):
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [inner]
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret]
if not isinstance(time, Sequence):
raise TypeError(
f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead."
)
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
if not time:
raise ValueError("time parameter must not be an empty sequence.")
raise ValueError('time parameter must not be an empty sequence.')
ret: List[datetime.time] = []
ret = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead."
)
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval(
@ -661,7 +669,7 @@ class Loop(Generic[LF]):
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError("Total number of seconds cannot be less than zero.")
raise ValueError('Total number of seconds cannot be less than zero.')
self._sleep = sleep
self._seconds = float(seconds)
@ -670,7 +678,7 @@ class Loop(Generic[LF]):
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError("Cannot mix explicit time with relative time")
raise TypeError('Cannot mix explicit time with relative time')
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING
@ -693,8 +701,8 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Callable[[Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]], Loop[C, P, T]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -709,7 +717,7 @@ def loop(
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time.
If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters.
@ -726,7 +734,7 @@ def loop(
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
loop: Optional[:class:`asyncio.AbstractEventLoop`]
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
@ -738,17 +746,15 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters.
"""
def decorator(func: LF) -> Loop[LF]:
return Loop[LF](
func,
seconds=seconds,
minutes=minutes,
hours=hours,
count=count,
time=time,
reconnect=reconnect,
loop=loop,
)
def decorator(func: Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]) -> Loop[C, P, T]:
kwargs = {
'seconds': seconds,
'minutes': minutes,
'hours': hours,
'count': count,
'time': time,
'reconnect': reconnect,
'loop': loop,
}
return Loop[C, P, T](func, **kwargs)
return decorator

View File

@ -28,7 +28,9 @@ from typing import Optional, TYPE_CHECKING, Union
import os
import io
__all__ = ("File",)
__all__ = (
'File',
)
class File:
@ -62,7 +64,7 @@ class File:
Whether the attachment is a spoiler.
"""
__slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer")
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer')
if TYPE_CHECKING:
fp: io.BufferedIOBase
@ -78,12 +80,12 @@ class File:
):
if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()):
raise ValueError(f"File buffer {fp!r} must be seekable and readable")
raise ValueError(f'File buffer {fp!r} must be seekable and readable')
self.fp = fp
self._original_pos = fp.tell()
self._owner = False
else:
self.fp = open(fp, "rb")
self.fp = open(fp, 'rb')
self._original_pos = 0
self._owner = True
@ -98,14 +100,14 @@ class File:
if isinstance(fp, str):
_, self.filename = os.path.split(fp)
else:
self.filename = getattr(fp, "name", None)
self.filename = getattr(fp, 'name', None)
else:
self.filename = filename
if spoiler and self.filename is not None and not self.filename.startswith("SPOILER_"):
self.filename = "SPOILER_" + self.filename
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
self.filename = 'SPOILER_' + self.filename
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith("SPOILER_"))
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_'))
def reset(self, *, seek: Union[int, bool] = True) -> None:
# The `seek` parameter is needed because

View File

@ -29,19 +29,19 @@ from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optio
from .enums import UserFlags
__all__ = (
"SystemChannelFlags",
"MessageFlags",
"PublicUserFlags",
"Intents",
"MemberCacheFlags",
"ApplicationFlags",
'SystemChannelFlags',
'MessageFlags',
'PublicUserFlags',
'Intents',
'MemberCacheFlags',
'ApplicationFlags',
)
FV = TypeVar("FV", bound="flag_value")
BF = TypeVar("BF", bound="BaseFlags")
FV = TypeVar('FV', bound='flag_value')
BF = TypeVar('BF', bound='BaseFlags')
class flag_value:
class flag_value(Generic[BF]):
def __init__(self, func: Callable[[Any], int]):
self.flag = func(None)
self.__doc__ = func.__doc__
@ -63,7 +63,7 @@ class flag_value:
instance._set_flag(self.flag, value)
def __repr__(self):
return f"<flag_value flag={self.flag!r}>"
return f'<flag_value flag={self.flag!r}>'
class alias_flag_value(flag_value):
@ -98,13 +98,13 @@ class BaseFlags:
value: int
__slots__ = ("value",)
__slots__ = ('value',)
def __init__(self, **kwargs: bool):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f"{key!r} is not a valid flag name.")
raise TypeError(f'{key!r} is not a valid flag name.')
setattr(self, key, value)
@classmethod
@ -123,7 +123,7 @@ class BaseFlags:
return hash(self.value)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} value={self.value}>"
return f'<{self.__class__.__name__} value={self.value}>'
def __iter__(self) -> Iterator[Tuple[str, bool]]:
for name, value in self.__class__.__dict__.items():
@ -142,7 +142,7 @@ class BaseFlags:
elif toggle is False:
self.value &= ~o
else:
raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.")
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.')
@fill_with_flags(inverted=True)
@ -196,7 +196,7 @@ class SystemChannelFlags(BaseFlags):
elif toggle is False:
self.value |= o
else:
raise TypeError("Value to set for SystemChannelFlags must be a bool.")
raise TypeError('Value to set for SystemChannelFlags must be a bool.')
@flag_value
def join_notifications(self):
@ -205,7 +205,7 @@ class SystemChannelFlags(BaseFlags):
@flag_value
def premium_subscriptions(self):
""":class:`bool`: Returns ``True`` if the system channel is used for "Nitro boosting" notifications."""
""":class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
return 2
@flag_value
@ -287,15 +287,6 @@ class MessageFlags(BaseFlags):
"""
return 32
@flag_value
def ephemeral(self):
""":class:`bool`: Returns ``True`` if the source message is ephemeral.
.. versionadded:: 2.0
"""
return 64
@fill_with_flags()
class PublicUserFlags(BaseFlags):
r"""Wraps up the Discord User Public flags.
@ -461,7 +452,7 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f"{key!r} is not a valid flag name.")
raise TypeError(f'{key!r} is not a valid flag name.')
setattr(self, key, value)
@classmethod
@ -480,6 +471,16 @@ class Intents(BaseFlags):
self.value = self.DEFAULT_VALUE
return self
@classmethod
def default(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled
except :attr:`presences` and :attr:`members`.
"""
self = cls.all()
self.presences = False
self.members = False
return self
@flag_value
def guilds(self):
""":class:`bool`: Whether guild related events are enabled.
@ -514,7 +515,7 @@ class Intents(BaseFlags):
- :func:`on_member_join`
- :func:`on_member_remove`
- :func:`on_member_update`
- :func:`on_member_update` (nickname, roles)
- :func:`on_user_update`
This also corresponds to the following attributes and classes in terms of cache:
@ -556,34 +557,18 @@ class Intents(BaseFlags):
@flag_value
def emojis(self):
""":class:`bool`: Alias of :attr:`.emojis_and_stickers`.
.. versionchanged:: 2.0
Changed to an alias.
"""
return 1 << 3
@alias_flag_value
def emojis_and_stickers(self):
""":class:`bool`: Whether guild emoji and sticker related events are enabled.
.. versionadded:: 2.0
""":class:`bool`: Whether guild emoji related events are enabled.
This corresponds to the following events:
- :func:`on_guild_emojis_update`
- :func:`on_guild_stickers_update`
This also corresponds to the following attributes and classes in terms of cache:
- :class:`Emoji`
- :class:`GuildSticker`
- :meth:`Client.get_emoji`
- :meth:`Client.get_sticker`
- :meth:`Client.emojis`
- :meth:`Client.stickers`
- :attr:`Guild.emojis`
- :attr:`Guild.stickers`
"""
return 1 << 3
@ -640,10 +625,6 @@ class Intents(BaseFlags):
- :attr:`VoiceChannel.members`
- :attr:`VoiceChannel.voice_states`
- :attr:`Member.voice`
.. note::
This intent is required to connect to voice.
"""
return 1 << 7
@ -653,7 +634,7 @@ class Intents(BaseFlags):
This corresponds to the following events:
- :func:`on_presence_update`
- :func:`on_member_update` (activities, status)
This also corresponds to the following attributes and classes in terms of cache:
@ -907,7 +888,7 @@ class MemberCacheFlags(BaseFlags):
self.value = (1 << bits) - 1
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f"{key!r} is not a valid flag name.")
raise TypeError(f'{key!r} is not a valid flag name.')
setattr(self, key, value)
@classmethod
@ -977,10 +958,10 @@ class MemberCacheFlags(BaseFlags):
def _verify_intents(self, intents: Intents):
if self.voice and not intents.voice_states:
raise ValueError("MemberCacheFlags.voice requires Intents.voice_states")
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states')
if self.joined and not intents.members:
raise ValueError("MemberCacheFlags.joined requires Intents.members")
raise ValueError('MemberCacheFlags.joined requires Intents.members')
@property
def _voice_only(self):

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -25,18 +25,18 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING, overload, Type, Tuple
from .utils import _get_as_snowflake, parse_time, MISSING
from typing import Optional, TYPE_CHECKING, overload, Type, Tuple
from .utils import _get_as_snowflake, get, parse_time
from .user import User
from .errors import InvalidArgument
from .enums import try_enum, ExpireBehaviour
__all__ = (
"IntegrationAccount",
"IntegrationApplication",
"Integration",
"StreamIntegration",
"BotIntegration",
'IntegrationAccount',
'IntegrationApplication',
'Integration',
'StreamIntegration',
'BotIntegration',
)
if TYPE_CHECKING:
@ -59,20 +59,20 @@ class IntegrationAccount:
Attributes
-----------
id: :class:`str`
id: :class:`int`
The account ID.
name: :class:`str`
The account name.
"""
__slots__ = ("id", "name")
__slots__ = ('id', 'name')
def __init__(self, data: IntegrationAccountPayload) -> None:
self.id: str = data["id"]
self.name: str = data["name"]
self.id: Optional[int] = _get_as_snowflake(data, 'id')
self.name: str = data.pop('name')
def __repr__(self) -> str:
return f"<IntegrationAccount id={self.id} name={self.name!r}>"
return f'<IntegrationAccount id={self.id} name={self.name!r}>'
class Integration:
@ -99,14 +99,14 @@ class Integration:
"""
__slots__ = (
"guild",
"id",
"_state",
"type",
"name",
"account",
"user",
"enabled",
'guild',
'id',
'_state',
'type',
'name',
'account',
'user',
'enabled',
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
@ -118,16 +118,16 @@ class Integration:
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
self.id: int = int(data["id"])
self.type: IntegrationType = data["type"]
self.name: str = data["name"]
self.account: IntegrationAccount = IntegrationAccount(data["account"])
self.id: int = int(data['id'])
self.type: IntegrationType = data['type']
self.name: str = data['name']
self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get("user")
user = data.get('user')
self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data["enabled"]
self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None:
async def delete(self) -> None:
"""|coro|
Deletes the integration.
@ -135,13 +135,6 @@ class Integration:
You must have the :attr:`~Permissions.manage_guild` permission to
do this.
Parameters
-----------
reason: :class:`str`
The reason the integration was deleted. Shows up on the audit log.
.. versionadded:: 2.0
Raises
-------
Forbidden
@ -149,7 +142,7 @@ class Integration:
HTTPException
Deleting the integration failed.
"""
await self._state.http.delete_integration(self.guild.id, self.id, reason=reason)
await self._state.http.delete_integration(self.guild.id, self.id)
class StreamIntegration(Integration):
@ -186,44 +179,48 @@ class StreamIntegration(Integration):
"""
__slots__ = (
"revoked",
"expire_behaviour",
"expire_grace_period",
"synced_at",
"_role_id",
"syncing",
"enable_emoticons",
"subscriber_count",
'revoked',
'expire_behaviour',
'expire_behavior',
'expire_grace_period',
'synced_at',
'_role_id',
'syncing',
'enable_emoticons',
'subscriber_count',
)
def _from_data(self, data: StreamIntegrationPayload) -> None:
super()._from_data(data)
self.revoked: bool = data["revoked"]
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"])
self.expire_grace_period: int = data["expire_grace_period"]
self.synced_at: datetime.datetime = parse_time(data["synced_at"])
self._role_id: Optional[int] = _get_as_snowflake(data, "role_id")
self.syncing: bool = data["syncing"]
self.enable_emoticons: bool = data["enable_emoticons"]
self.subscriber_count: int = data["subscriber_count"]
@property
def expire_behavior(self) -> ExpireBehaviour:
""":class:`ExpireBehaviour`: An alias for :attr:`expire_behaviour`."""
return self.expire_behaviour
self.revoked: bool = data['revoked']
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior'])
self.expire_grace_period: int = data['expire_grace_period']
self.synced_at: datetime.datetime = parse_time(data['synced_at'])
self._role_id: int = int(data['role_id'])
self.syncing: bool = data['syncing']
self.enable_emoticons: bool = data['enable_emoticons']
self.subscriber_count: int = data['subscriber_count']
@property
def role(self) -> Optional[Role]:
"""Optional[:class:`Role`] The role which the integration uses for subscribers."""
return self.guild.get_role(self._role_id) # type: ignore
return self.guild.get_role(self._role_id)
@overload
async def edit(
self,
*,
expire_behaviour: ExpireBehaviour = MISSING,
expire_grace_period: int = MISSING,
enable_emoticons: bool = MISSING,
expire_behaviour: Optional[ExpireBehaviour] = ...,
expire_grace_period: Optional[int] = ...,
enable_emoticons: Optional[bool] = ...,
) -> None:
...
@overload
async def edit(self, **fields) -> None:
...
async def edit(self, **fields) -> None:
"""|coro|
Edits the integration.
@ -249,23 +246,35 @@ class StreamIntegration(Integration):
InvalidArgument
``expire_behaviour`` did not receive a :class:`ExpireBehaviour`.
"""
payload: Dict[str, Any] = {}
if expire_behaviour is not MISSING:
if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument("expire_behaviour field must be of type ExpireBehaviour")
try:
expire_behaviour = fields['expire_behaviour']
except KeyError:
expire_behaviour = fields.get('expire_behavior', self.expire_behaviour)
payload["expire_behavior"] = expire_behaviour.value
if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour')
if expire_grace_period is not MISSING:
payload["expire_grace_period"] = expire_grace_period
expire_grace_period = fields.get('expire_grace_period', self.expire_grace_period)
if enable_emoticons is not MISSING:
payload["enable_emoticons"] = enable_emoticons
payload = {
'expire_behavior': expire_behaviour.value,
'expire_grace_period': expire_grace_period,
}
try:
enable_emoticons = fields['enable_emoticons']
except KeyError:
enable_emoticons = self.enable_emoticons
else:
payload['enable_emoticons'] = enable_emoticons
# This endpoint is undocumented.
# Unsure if it returns the data or not as a result
await self._state.http.edit_integration(self.guild.id, self.id, **payload)
self.expire_behaviour = expire_behaviour
self.expire_behavior = self.expire_behaviour
self.expire_grace_period = expire_grace_period
self.enable_emoticons = enable_emoticons
async def sync(self) -> None:
"""|coro|
@ -307,21 +316,21 @@ class IntegrationApplication:
"""
__slots__ = (
"id",
"name",
"icon",
"description",
"summary",
"user",
'id',
'name',
'icon',
'description',
'summary',
'user',
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
self.id: int = int(data["id"])
self.name: str = data["name"]
self.icon: Optional[str] = data["icon"]
self.description: str = data["description"]
self.summary: str = data["summary"]
user = data.get("bot")
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
self.description: str = data['description']
self.summary: str = data['summary']
user = data.get('bot')
self.user: Optional[User] = User(state=state, data=user) if user else None
@ -350,17 +359,17 @@ class BotIntegration(Integration):
The application tied to this integration.
"""
__slots__ = ("application",)
__slots__ = ('application',)
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data["application"], state=self._state)
self.application = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
if value == "discord":
if value == 'discord':
return BotIntegration, value
elif value in ("twitch", "youtube"):
elif value in ('twitch', 'youtube'):
return StreamIntegration, value
else:
return Integration, value

View File

@ -25,46 +25,33 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from discord.types.interactions import InteractionResponse
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
import asyncio
from . import utils
from .enums import try_enum, InteractionType, InteractionResponseType
from .errors import InteractionResponded, HTTPException, ClientException
from .channel import PartialMessageable, ChannelType
from .user import User
from .member import Member
from .message import Message, Attachment
from .object import Object
from .permissions import Permissions
from .webhook.async_ import async_context, Webhook, handle_message_parameters
from .webhook.async_ import async_context, Webhook
__all__ = (
"Interaction",
"InteractionMessage",
"InteractionResponse",
'Interaction',
'InteractionResponse',
)
if TYPE_CHECKING:
from datetime import datetime
from .types.interactions import (
Interaction as InteractionPayload,
ApplicationCommandOptionChoice,
InteractionData,
)
from .guild import Guild
from .abc import GuildChannel
from .state import ConnectionState
from .file import File
from .mentions import AllowedMentions
from aiohttp import ClientSession
from .embeds import Embed
from .ui.view import View
from .channel import TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .threads import Thread
InteractionChannel = Union[TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable]
MISSING: Any = utils.MISSING
@ -73,7 +60,8 @@ class Interaction:
"""Represents a Discord interaction.
An interaction happens when a user does an action that needs to
be notified. Current examples are slash commands and components.
be notified. Current examples are slash commands but future examples
include forms and buttons.
.. versionadded:: 2.0
@ -96,68 +84,58 @@ class Interaction:
token: :class:`str`
The token to continue the interaction. These are valid
for 15 minutes.
data: :class:`dict`
The raw interaction data.
"""
__slots__: Tuple[str, ...] = (
"id",
"type",
"guild_id",
"channel_id",
"data",
"application_id",
"message",
"user",
"token",
"version",
"_permissions",
"_state",
"_session",
"_original_message",
"_cs_response",
"_cs_followup",
"_cs_channel",
'id',
'type',
'guild_id',
'channel_id',
'data',
'application_id',
'message',
'user',
'token',
'version',
'_state',
'_session',
'_cs_response',
'_cs_followup',
)
def __init__(self, *, data: InteractionPayload, state: ConnectionState):
self._state: ConnectionState = state
self._state = state
self._session: ClientSession = state.http._HTTPClient__session
self._original_message: Optional[InteractionMessage] = None
self._from_data(data)
def _from_data(self, data: InteractionPayload):
self.id: int = int(data["id"])
self.type: InteractionType = try_enum(InteractionType, data["type"])
self.data: Optional[InteractionData] = data.get("data")
self.token: str = data["token"]
self.version: int = data["version"]
self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id")
self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id")
self.application_id: int = int(data["application_id"])
self.id = int(data['id'])
self.type = try_enum(InteractionType, data['type'])
self.data = data.get('data')
self.token = data['token']
self.version = data['version']
self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.application_id = utils._get_as_snowflake(data, 'application_id')
self.message: Optional[Message]
channel = self.channel or Object(id=self.channel_id)
try:
self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore
self.message = Message(state=self._state, channel=channel, data=data['message'])
except KeyError:
self.message = None
self.user: Optional[Union[User, Member]] = None
self._permissions: int = 0
# TODO: there's a potential data loss here
if self.guild_id:
guild = self.guild or Object(id=self.guild_id)
try:
member = data["member"] # type: ignore
self.user = Member(state=self._state, guild=guild, data=data['member'])
except KeyError:
pass
else:
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
self._permissions = int(member.get("permissions", 0))
else:
try:
self.user = User(state=self._state, data=data["user"])
self.user = User(state=self._state, data=data['user'])
except KeyError:
pass
@ -166,200 +144,31 @@ class Interaction:
"""Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id)
@utils.cached_slot_property("_cs_channel")
def channel(self) -> Optional[InteractionChannel]:
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from.
@property
def channel(self) -> Optional[GuildChannel]:
"""Optional[:class:`abc.GuildChannel`]: The channel the interaction was sent from.
Note that due to a Discord limitation, DM channels are not resolved since there is
no data to complete them. These are :class:`PartialMessageable` instead.
no data to complete them.
"""
guild = self.guild
channel = guild and guild._resolve_channel(self.channel_id)
if channel is None:
if self.channel_id is not None:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
return None
return channel # type: ignore
return guild and guild.get_channel(self.channel_id)
@property
def permissions(self) -> Permissions:
""":class:`Permissions`: The resolved permissions of the member in the channel, including overwrites.
In a non-guild context where this doesn't apply, an empty permissions object is returned.
"""
return Permissions(self._permissions)
@utils.cached_slot_property("_cs_response")
@utils.cached_slot_property('_cs_response')
def response(self) -> InteractionResponse:
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
A response can only be done once. If secondary messages need to be sent, consider using :attr:`followup`
instead.
"""
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction."""
return InteractionResponse(self)
@utils.cached_slot_property("_cs_followup")
@utils.cached_slot_property('_cs_followup')
def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = {
"id": self.application_id,
"type": 3,
"token": self.token,
'id': self.application_id,
'type': 3,
'token': self.token,
}
return Webhook.from_state(data=payload, state=self._state)
async def original_message(self) -> InteractionMessage:
"""|coro|
Fetches the original interaction response message associated with the interaction.
If the interaction response was :meth:`InteractionResponse.send_message` then this would
return the message that was sent using that response. Otherwise, this would return
the message that triggered the interaction.
Repeated calls to this will return a cached value.
Raises
-------
HTTPException
Fetching the original response message failed.
ClientException
The channel for the message could not be resolved.
Returns
--------
InteractionMessage
The original interaction response message.
"""
if self._original_message is not None:
return self._original_message
# TODO: fix later to not raise?
channel = self.channel
if channel is None:
raise ClientException("Channel for message could not be resolved")
adapter = async_context.get()
data = await adapter.get_original_interaction_response(
application_id=self.application_id,
token=self.token,
session=self._session,
)
state = _InteractionMessageState(self, self._state)
message = InteractionMessage(state=state, channel=channel, data=data) # type: ignore
self._original_message = message
return message
async def edit_original_message(
self,
*,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> InteractionMessage:
"""|coro|
Edits the original interaction response message.
This is a lower level interface to :meth:`InteractionMessage.edit` in case
you do not want to fetch the message and save an HTTP request.
This method is also the only way to edit the original message if
the message sent was ephemeral.
Parameters
------------
content: Optional[:class:`str`]
The content to edit the message with or ``None`` to clear it.
embeds: List[:class:`Embed`]
A list of embeds to edit the message with.
embed: Optional[:class:`Embed`]
The embed to edit the message with. ``None`` suppresses the embeds.
This should not be mixed with the ``embeds`` parameter.
file: :class:`File`
The file to upload. This cannot be mixed with ``files`` parameter.
files: List[:class:`File`]
A list of files to send with the content. This cannot be mixed with the
``file`` parameter.
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
Raises
-------
HTTPException
Editing the message failed.
Forbidden
Edited a message that is not yours.
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid.
Returns
--------
:class:`InteractionMessage`
The newly edited message.
"""
previous_mentions: Optional[AllowedMentions] = self._state.allowed_mentions
params = handle_message_parameters(
content=content,
file=file,
files=files,
embed=embed,
embeds=embeds,
view=view,
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter = async_context.get()
data = await adapter.edit_original_interaction_response(
self.application_id,
self.token,
session=self._session,
payload=params.payload,
multipart=params.multipart,
files=params.files,
)
# The message channel types should always match
message = InteractionMessage(state=self._state, channel=self.channel, data=data) # type: ignore
if view and not view.is_finished():
self._state.store_view(view, message.id)
return message
async def delete_original_message(self) -> None:
"""|coro|
Deletes the original interaction response message.
This is a lower level interface to :meth:`InteractionMessage.delete` in case
you do not want to fetch the message and save an HTTP request.
Raises
-------
HTTPException
Deleting the message failed.
Forbidden
Deleted a message that is not yours.
"""
adapter = async_context.get()
await adapter.delete_original_interaction_response(
self.application_id,
self.token,
session=self._session,
)
class InteractionResponse:
"""Represents a Discord interaction response.
@ -370,20 +179,13 @@ class InteractionResponse:
"""
__slots__: Tuple[str, ...] = (
"responded_at",
"_parent",
'_responded',
'_parent',
)
def __init__(self, parent: Interaction):
self.responded_at: Optional[datetime] = None
self._parent: Interaction = parent
def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before.
An interaction can only be responded to once.
"""
return self.responded_at is not None
self._responded: bool = False
async def defer(self, *, ephemeral: bool = False) -> None:
"""|coro|
@ -403,11 +205,9 @@ class InteractionResponse:
-------
HTTPException
Deferring the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
if self._responded:
return
defer_type: int = 0
data: Optional[Dict[str, Any]] = None
@ -417,15 +217,14 @@ class InteractionResponse:
elif parent.type is InteractionType.application_command:
defer_type = InteractionResponseType.deferred_channel_message.value
if ephemeral:
data = {"flags": 64}
data = {'flags': 64}
if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=defer_type, data=data
)
self.responded_at = utils.utcnow()
self._responded = True
async def pong(self) -> None:
"""|coro|
@ -438,11 +237,9 @@ class InteractionResponse:
-------
HTTPException
Ponging the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
if self._responded:
return
parent = self._parent
if parent.type is InteractionType.ping:
@ -450,7 +247,7 @@ class InteractionResponse:
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
)
self.responded_at = utils.utcnow()
self._responded = True
async def send_message(
self,
@ -493,35 +290,33 @@ class InteractionResponse:
You specified both ``embed`` and ``embeds``.
ValueError
The length of ``embeds`` was invalid.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
if self._responded:
return
payload: Dict[str, Any] = {
"tts": tts,
'tts': tts,
}
if embed is not MISSING and embeds is not MISSING:
raise TypeError("cannot mix embed and embeds keyword arguments")
raise TypeError('cannot mix embed and embeds keyword arguments')
if embed is not MISSING:
embeds = [embed]
if embeds:
if len(embeds) > 10:
raise ValueError("embeds cannot exceed maximum of 10 elements")
payload["embeds"] = [e.to_dict() for e in embeds]
raise ValueError('embeds cannot exceed maximum of 10 elements')
payload['embeds'] = [e.to_dict() for e in embeds]
if content is not None:
payload["content"] = str(content)
payload['content'] = str(content)
if ephemeral:
payload["flags"] = 64
payload['flags'] = 64
if view is not MISSING:
payload["components"] = view.to_components()
payload['components'] = view.to_components()
parent = self._parent
adapter = async_context.get()
@ -539,7 +334,7 @@ class InteractionResponse:
self._parent._state.store_view(view)
self.responded_at = utils.utcnow()
self._responded = True
async def edit_message(
self,
@ -577,11 +372,9 @@ class InteractionResponse:
Editing the message failed.
TypeError
You specified both ``embed`` and ``embeds``.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
if self._responded:
return
parent = self._parent
msg = parent.message
@ -590,15 +383,16 @@ class InteractionResponse:
if parent.type is not InteractionType.component:
return
# TODO: embeds: List[Embed]?
payload = {}
if content is not MISSING:
if content is None:
payload["content"] = None
payload['content'] = None
else:
payload["content"] = str(content)
payload['content'] = str(content)
if embed is not MISSING and embeds is not MISSING:
raise TypeError("cannot mix both embed and embeds keyword arguments")
raise TypeError('cannot mix both embed and embeds keyword arguments')
if embed is not MISSING:
if embed is None:
@ -607,17 +401,17 @@ class InteractionResponse:
embeds = [embed]
if embeds is not MISSING:
payload["embeds"] = [e.to_dict() for e in embeds]
payload['embeds'] = [e.to_dict() for e in embeds]
if attachments is not MISSING:
payload["attachments"] = [a.to_dict() for a in attachments]
payload['attachments'] = [a.to_dict() for a in attachments]
if view is not MISSING:
state.prevent_view_updates_for(message_id)
if view is None:
payload["components"] = []
payload['components'] = []
else:
payload["components"] = view.to_components()
payload['components'] = view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
@ -631,176 +425,4 @@ class InteractionResponse:
if view and not view.is_finished():
state.store_view(view, message_id)
self.responded_at = utils.utcnow()
async def autocomplete_result(self, choices: List[ApplicationCommandOptionChoice]):
"""|coro|
Responds to this autocomplete interaction with the resulting choices.
This should rarely be used.
Parameters
-----------
choices: List[Dict[:class:`str`, :class:`str`]]
The choices to be shown in the autocomplete UI of the user.
Must be a list of dictionaries with the ``name`` and ``value`` keys.
Raises
-------
HTTPException
Responding to the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self.is_done():
raise InteractionResponded(self._parent)
parent = self._parent
if parent.type is not InteractionType.application_command_autocomplete:
return
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.application_command_autocomplete_result.value,
data={"choices": choices},
)
self.responded_at = utils.utcnow()
class _InteractionMessageState:
__slots__ = ("_parent", "_interaction")
def __init__(self, interaction: Interaction, parent: ConnectionState):
self._interaction: Interaction = interaction
self._parent: ConnectionState = parent
def _get_guild(self, guild_id):
return self._parent._get_guild(guild_id)
def store_user(self, data):
return self._parent.store_user(data)
def create_user(self, data):
return self._parent.create_user(data)
@property
def http(self):
return self._parent.http
def __getattr__(self, attr):
return getattr(self._parent, attr)
class InteractionMessage(Message):
"""Represents the original interaction response message.
This allows you to edit or delete the message associated with
the interaction response. To retrieve this object see :meth:`Interaction.original_message`.
This inherits from :class:`discord.Message` with changes to
:meth:`edit` and :meth:`delete` to work.
.. versionadded:: 2.0
"""
__slots__ = ()
_state: _InteractionMessageState
async def edit(
self,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> InteractionMessage:
"""|coro|
Edits the message.
Parameters
------------
content: Optional[:class:`str`]
The content to edit the message with or ``None`` to clear it.
embeds: List[:class:`Embed`]
A list of embeds to edit the message with.
embed: Optional[:class:`Embed`]
The embed to edit the message with. ``None`` suppresses the embeds.
This should not be mixed with the ``embeds`` parameter.
file: :class:`File`
The file to upload. This cannot be mixed with ``files`` parameter.
files: List[:class:`File`]
A list of files to send with the content. This cannot be mixed with the
``file`` parameter.
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
Raises
-------
HTTPException
Editing the message failed.
Forbidden
Edited a message that is not yours.
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid.
Returns
---------
:class:`InteractionMessage`
The newly edited message.
"""
return await self._state._interaction.edit_original_message(
content=content,
embeds=embeds,
embed=embed,
file=file,
files=files,
view=view,
allowed_mentions=allowed_mentions,
)
async def delete(self, *, delay: Optional[float] = None) -> None:
"""|coro|
Deletes the message.
Parameters
-----------
delay: Optional[:class:`float`]
If provided, the number of seconds to wait before deleting the message.
The waiting is done in the background and deletion failures are ignored.
Raises
------
Forbidden
You do not have proper permissions to delete the message.
NotFound
The message was deleted already.
HTTPException
Deleting the message failed.
"""
if delay is not None:
async def inner_call(delay: float = delay):
await asyncio.sleep(delay)
try:
await self._state._interaction.delete_original_message()
except HTTPException:
pass
asyncio.create_task(inner_call())
else:
await self._state._interaction.delete_original_message()
self._responded = True

View File

@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum
from .appinfo import PartialAppInfo
__all__ = (
"PartialInviteChannel",
"PartialInviteGuild",
"Invite",
'PartialInviteChannel',
'PartialInviteGuild',
'Invite',
)
if TYPE_CHECKING:
@ -52,8 +52,8 @@ if TYPE_CHECKING:
from .abc import GuildChannel
from .user import User
InviteGuildType = Union[Guild, "PartialInviteGuild", Object]
InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object]
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object]
import datetime
@ -92,23 +92,23 @@ class PartialInviteChannel:
The partial channel's type.
"""
__slots__ = ("id", "name", "type")
__slots__ = ('id', 'name', 'type')
def __init__(self, data: InviteChannelPayload):
self.id: int = int(data["id"])
self.name: str = data["name"]
self.type: ChannelType = try_enum(ChannelType, data["type"])
self.id: int = int(data['id'])
self.name: str = data['name']
self.type: ChannelType = try_enum(ChannelType, data['type'])
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return f"<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>"
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>'
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f"<#{self.id}>"
return f'<#{self.id}>'
@property
def created_at(self) -> datetime.datetime:
@ -154,26 +154,26 @@ class PartialInviteGuild:
The partial guild's description.
"""
__slots__ = ("_state", "features", "_icon", "_banner", "id", "name", "_splash", "verification_level", "description")
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description')
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int):
self._state: ConnectionState = state
self.id: int = id
self.name: str = data["name"]
self.features: List[str] = data.get("features", [])
self._icon: Optional[str] = data.get("icon")
self._banner: Optional[str] = data.get("banner")
self._splash: Optional[str] = data.get("splash")
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get("verification_level"))
self.description: Optional[str] = data.get("description")
self.name: str = data['name']
self.features: List[str] = data.get('features', [])
self._icon: Optional[str] = data.get('icon')
self._banner: Optional[str] = data.get('banner')
self._splash: Optional[str] = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description')
def __str__(self) -> str:
return self.name
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} "
f"description={self.description!r}>"
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} '
f'description={self.description!r}>'
)
@property
@ -193,17 +193,17 @@ class PartialInviteGuild:
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available."""
if self._banner is None:
return None
return Asset._from_guild_image(self._state, self.id, self._banner, path="banners")
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners')
@property
def splash(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available."""
if self._splash is None:
return None
return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes")
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
I = TypeVar("I", bound="Invite")
I = TypeVar('I', bound='Invite')
class Invite(Hashable):
@ -230,7 +230,6 @@ class Invite(Hashable):
Returns the invite URL.
The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+
@ -258,7 +257,7 @@ class Invite(Hashable):
Attributes
-----------
max_age: :class:`int`
How long before the invite expires in seconds.
How long the before the invite expires in seconds.
A value of ``0`` indicates that it doesn't expire.
code: :class:`str`
The URL fragment used for the invite.
@ -308,26 +307,26 @@ class Invite(Hashable):
"""
__slots__ = (
"max_age",
"code",
"guild",
"revoked",
"created_at",
"uses",
"temporary",
"max_uses",
"inviter",
"channel",
"target_user",
"target_type",
"_state",
"approximate_member_count",
"approximate_presence_count",
"target_application",
"expires_at",
'max_age',
'code',
'guild',
'revoked',
'created_at',
'uses',
'temporary',
'max_uses',
'inviter',
'channel',
'target_user',
'target_type',
'_state',
'approximate_member_count',
'approximate_presence_count',
'target_application',
'expires_at',
)
BASE = "https://discord.gg"
BASE = 'https://discord.gg'
def __init__(
self,
@ -338,33 +337,31 @@ class Invite(Hashable):
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
):
self._state: ConnectionState = state
self.max_age: Optional[int] = data.get("max_age")
self.code: str = data["code"]
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild)
self.revoked: Optional[bool] = data.get("revoked")
self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.temporary: Optional[bool] = data.get("temporary")
self.uses: Optional[int] = data.get("uses")
self.max_uses: Optional[int] = data.get("max_uses")
self.approximate_presence_count: Optional[int] = data.get("approximate_presence_count")
self.approximate_member_count: Optional[int] = data.get("approximate_member_count")
self.max_age: Optional[int] = data.get('max_age')
self.code: str = data['code']
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild)
self.revoked: Optional[bool] = data.get('revoked')
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.temporary: Optional[bool] = data.get('temporary')
self.uses: Optional[int] = data.get('uses')
self.max_uses: Optional[int] = data.get('max_uses')
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
expires_at = data.get("expires_at", None)
expires_at = data.get('expires_at', None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
inviter_data = data.get("inviter")
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data)
inviter_data = data.get('inviter')
self.inviter: Optional[User] = None if inviter_data is None else self._state.store_user(inviter_data)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get("channel"), channel)
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
target_user_data = data.get("target_user")
self.target_user: Optional[User] = (
None if target_user_data is None else self._state.create_user(target_user_data)
)
target_user_data = data.get('target_user')
self.target_user: Optional[User] = None if target_user_data is None else self._state.store_user(target_user_data)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
application = data.get("target_application")
application = data.get('target_application')
self.target_application: Optional[PartialAppInfo] = (
PartialAppInfo(data=application, state=state) if application else None
)
@ -373,12 +370,12 @@ class Invite(Hashable):
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I:
guild: Optional[Union[Guild, PartialInviteGuild]]
try:
guild_data = data["guild"]
guild_data = data['guild']
except KeyError:
# If we're here, then this is a group DM
guild = None
else:
guild_id = int(guild_data["id"])
guild_id = int(guild_data['id'])
guild = state._get_guild(guild_id)
if guild is None:
# If it's not cached, then it has to be a partial guild
@ -386,7 +383,7 @@ class Invite(Hashable):
# As far as I know, invites always need a channel
# So this should never raise.
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data["channel"])
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel'])
if guild is not None and not isinstance(guild, PartialInviteGuild):
# Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel
@ -395,9 +392,9 @@ class Invite(Hashable):
@classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I:
guild_id: Optional[int] = _get_as_snowflake(data, "guild_id")
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id')
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id)
channel_id = int(data["channel_id"])
channel_id = int(data['channel_id'])
if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore
else:
@ -417,7 +414,7 @@ class Invite(Hashable):
if data is None:
return None
guild_id = int(data["id"])
guild_id = int(data['id'])
return PartialInviteGuild(self._state, data, guild_id)
def _resolve_channel(
@ -436,14 +433,11 @@ class Invite(Hashable):
def __str__(self) -> str:
return self.url
def __int__(self) -> int:
return 0 # To keep the object compatible with the hashable abc.
def __repr__(self) -> str:
return (
f"<Invite code={self.code!r} guild={self.guild!r} "
f"online={self.approximate_presence_count} "
f"members={self.approximate_member_count}>"
f'<Invite code={self.code!r} guild={self.guild!r} '
f'online={self.approximate_presence_count} '
f'members={self.approximate_member_count}>'
)
def __hash__(self) -> int:
@ -457,7 +451,7 @@ class Invite(Hashable):
@property
def url(self) -> str:
""":class:`str`: A property that retrieves the invite URL."""
return self.BASE + "/" + self.code
return self.BASE + '/' + self.code
async def delete(self, *, reason: Optional[str] = None):
"""|coro|

View File

@ -34,11 +34,11 @@ from .object import Object
from .audit_logs import AuditLogEntry
__all__ = (
"ReactionIterator",
"HistoryIterator",
"AuditLogIterator",
"GuildIterator",
"MemberIterator",
'ReactionIterator',
'HistoryIterator',
'AuditLogIterator',
'GuildIterator',
'MemberIterator',
)
if TYPE_CHECKING:
@ -67,8 +67,8 @@ if TYPE_CHECKING:
from .threads import Thread
from .abc import Snowflake
T = TypeVar("T")
OT = TypeVar("OT")
T = TypeVar('T')
OT = TypeVar('OT')
_Func = Callable[[T], Union[OT, Awaitable[OT]]]
OLDEST_OBJECT = Object(id=0)
@ -83,7 +83,7 @@ class _AsyncIterator(AsyncIterator[T]):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items():
nested = attr.split("__")
nested = attr.split('__')
obj = elem
for attribute in nested:
obj = getattr(obj, attribute)
@ -107,7 +107,7 @@ class _AsyncIterator(AsyncIterator[T]):
def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]:
if max_size <= 0:
raise ValueError("async iterator chunk sizes must be greater than 0.")
raise ValueError('async iterator chunk sizes must be greater than 0.')
return _ChunkedAsyncIterator(self, max_size)
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
@ -182,7 +182,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
return item
class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None):
self.message = message
self.limit = limit
@ -218,14 +218,14 @@ class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
if data:
self.limit -= retrieve
self.after = Object(id=int(data[-1]["id"]))
self.after = Object(id=int(data[-1]['id']))
if self.guild is None or isinstance(self.guild, Object):
for element in reversed(data):
await self.users.put(User(state=self.state, data=element))
else:
for element in reversed(data):
member_id = int(element["id"])
member_id = int(element['id'])
member = self.guild.get_member(member_id)
if member is not None:
await self.users.put(member)
@ -233,7 +233,7 @@ class ReactionIterator(_AsyncIterator[Union["User", "Member"]]):
await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator["Message"]):
class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history.
The messages endpoint has two behaviours we care about here:
@ -295,7 +295,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
if self.around:
if self.limit is None:
raise ValueError("history does not support around with limit=None")
raise ValueError('history does not support around with limit=None')
if self.limit > 101:
raise ValueError("history max limit 101 when specifying around parameter")
elif self.limit == 101:
@ -303,20 +303,20 @@ class HistoryIterator(_AsyncIterator["Message"]):
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore
if self.before and self.after:
self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before:
self._filter = lambda m: int(m["id"]) < self.before.id
self._filter = lambda m: int(m['id']) < self.before.id
elif self.after:
self._filter = lambda m: self.after.id < int(m["id"])
self._filter = lambda m: self.after.id < int(m['id'])
else:
if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore
if self.before:
self._filter = lambda m: int(m["id"]) < self.before.id
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m["id"]) > self.after.id
self._filter = lambda m: int(m['id']) > self.after.id
async def next(self) -> Message:
if self.messages.empty():
@ -337,7 +337,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
return r > 0
async def fill_messages(self):
if not hasattr(self, "channel"):
if not hasattr(self, 'channel'):
# do the required set up
channel = await self.messageable._get_channel()
self.channel = channel
@ -367,7 +367,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
self.before = Object(id=int(data[-1]['id']))
return data
async def _retrieve_messages_after_strategy(self, retrieve):
@ -377,7 +377,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]["id"]))
self.after = Object(id=int(data[0]['id']))
return data
async def _retrieve_messages_around_strategy(self, retrieve):
@ -390,7 +390,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
return []
class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False))
@ -420,11 +420,11 @@ class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m["id"]) < self.before.id
self._filter = lambda m: int(m['id']) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m["id"]) > self.after.id
self._filter = lambda m: int(m['id']) > self.after.id
async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None
@ -432,24 +432,24 @@ class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before
)
entries = data.get("audit_log_entries", [])
entries = data.get('audit_log_entries', [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]["id"]))
return data.get("users", []), entries
self.before = Object(id=int(entries[-1]['id']))
return data.get('users', []), entries
async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after
)
entries = data.get("audit_log_entries", [])
entries = data.get('audit_log_entries', [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]["id"]))
return data.get("users", []), entries
self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries
async def next(self) -> AuditLogEntry:
if self.entries.empty():
@ -488,13 +488,13 @@ class AuditLogIterator(_AsyncIterator["AuditLogEntry"]):
for element in data:
# TODO: remove this if statement later
if element["action_type"] is None:
if element['action_type'] is None:
continue
await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild))
class GuildIterator(_AsyncIterator["Guild"]):
class GuildIterator(_AsyncIterator['Guild']):
"""Iterator for receiving the client's guilds.
The guilds endpoint has the same two behaviours as described
@ -543,7 +543,7 @@ class GuildIterator(_AsyncIterator["Guild"]):
if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore
self._filter = lambda m: int(m["id"]) > self.after.id
self._filter = lambda m: int(m['id']) > self.after.id
elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore
else:
@ -595,7 +595,7 @@ class GuildIterator(_AsyncIterator["Guild"]):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
self.before = Object(id=int(data[-1]['id']))
return data
async def _retrieve_guilds_after_strategy(self, retrieve):
@ -605,11 +605,11 @@ class GuildIterator(_AsyncIterator["Guild"]):
if len(data):
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[0]["id"]))
self.after = Object(id=int(data[0]['id']))
return data
class MemberIterator(_AsyncIterator["Member"]):
class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None):
if isinstance(after, datetime.datetime):
@ -652,7 +652,7 @@ class MemberIterator(_AsyncIterator["Member"]):
if len(data) < 1000:
self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]["user"]["id"]))
self.after = Object(id=int(data[-1]['user']['id']))
for element in reversed(data):
await self.members.put(self.create_member(element))
@ -663,7 +663,7 @@ class MemberIterator(_AsyncIterator["Member"]):
return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator["Thread"]):
class ArchivedThreadIterator(_AsyncIterator['Thread']):
def __init__(
self,
channel_id: int,
@ -681,7 +681,7 @@ class ArchivedThreadIterator(_AsyncIterator["Thread"]):
self.http = guild._state.http
if joined and not private:
raise ValueError("Cannot iterate over joined public archived threads")
raise ValueError('Cannot iterate over joined public archived threads')
self.before: Optional[str]
if before is None:
@ -721,11 +721,11 @@ class ArchivedThreadIterator(_AsyncIterator["Thread"]):
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data["thread_metadata"]["archive_timestamp"]
return data['thread_metadata']['archive_timestamp']
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data["id"] # type: ignore
return data['id'] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
@ -735,11 +735,11 @@ class ArchivedThreadIterator(_AsyncIterator["Thread"]):
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get("threads", [])
threads: List[ThreadPayload] = data.get('threads', [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get("has_more", False)
self.has_more = data.get('has_more', False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
@ -750,5 +750,4 @@ class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)
return Thread(guild=self.guild, data=data)

View File

@ -29,46 +29,29 @@ import inspect
import itertools
import sys
from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
from typing import List, Literal, Optional, TYPE_CHECKING, Union, overload
import discord.abc
from . import utils
from .asset import Asset
from .utils import MISSING
from .user import BaseUser, User, _UserTag
from .activity import create_activity, ActivityTypes
from .user import BaseUser, User
from .activity import create_activity
from .permissions import Permissions
from .enums import Status, try_enum
from .colour import Colour
from .object import Object
__all__ = (
"VoiceState",
"Member",
'VoiceState',
'Member',
)
if TYPE_CHECKING:
from .asset import Asset
from .channel import DMChannel, VoiceChannel, StageChannel
from .flags import PublicUserFlags
from .guild import Guild
from .types.activity import PartialPresenceUpdate
from .types.member import (
MemberWithUser as MemberWithUserPayload,
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
from .types.user import User as UserPayload
from .channel import VoiceChannel, StageChannel
from .abc import Snowflake
from .state import ConnectionState
from .message import Message
from .role import Role
from .types.voice import VoiceState as VoiceStatePayload
VocalGuildChannel = Union[VoiceChannel, StageChannel]
class VoiceState:
"""Represents a Discord user's voice state.
@ -112,55 +95,42 @@ class VoiceState:
is not currently in a voice channel.
"""
__slots__ = (
"session_id",
"deaf",
"mute",
"self_mute",
"self_stream",
"self_video",
"self_deaf",
"afk",
"channel",
"requested_to_speak_at",
"suppress",
)
__slots__ = ('session_id', 'deaf', 'mute', 'self_mute',
'self_stream', 'self_video', 'self_deaf', 'afk', 'channel',
'requested_to_speak_at', 'suppress')
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None):
self.session_id: str = data.get("session_id")
def __init__(self, *, data, channel=None):
self.session_id = data.get('session_id')
self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]):
self.self_mute: bool = data.get("self_mute", False)
self.self_deaf: bool = data.get("self_deaf", False)
self.self_stream: bool = data.get("self_stream", False)
self.self_video: bool = data.get("self_video", False)
self.afk: bool = data.get("suppress", False)
self.mute: bool = data.get("mute", False)
self.deaf: bool = data.get("deaf", False)
self.suppress: bool = data.get("suppress", False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(
data.get("request_to_speak_timestamp")
)
self.channel: Optional[VocalGuildChannel] = channel
def _update(self, data, channel):
self.self_mute = data.get('self_mute', False)
self.self_deaf = data.get('self_deaf', False)
self.self_stream = data.get('self_stream', False)
self.self_video = data.get('self_video', False)
self.afk = data.get('suppress', False)
self.mute = data.get('mute', False)
self.deaf = data.get('deaf', False)
self.suppress = data.get('suppress', False)
self.requested_to_speak_at = utils.parse_time(data.get('request_to_speak_timestamp'))
self.channel = channel
def __repr__(self) -> str:
def __repr__(self):
attrs = [
("self_mute", self.self_mute),
("self_deaf", self.self_deaf),
("self_stream", self.self_stream),
("suppress", self.suppress),
("requested_to_speak_at", self.requested_to_speak_at),
("channel", self.channel),
('self_mute', self.self_mute),
('self_deaf', self.self_deaf),
('self_stream', self.self_stream),
('suppress', self.suppress),
('requested_to_speak_at', self.requested_to_speak_at),
('channel', self.channel)
]
inner = " ".join("%s=%r" % t for t in attrs)
return f"<{self.__class__.__name__} {inner}>"
inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>'
def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods
if attr.startswith("_"):
if attr.startswith('_'):
continue
# don't override what we already have
@ -169,9 +139,9 @@ def flatten_user(cls):
# if it's a slotted attribute or a property, redirect it
# slotted members are implemented as member_descriptors in Type.__dict__
if not hasattr(value, "__annotations__"):
getter = attrgetter("_user." + attr)
setattr(cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`"))
if not hasattr(value, '__annotations__'):
getter = attrgetter('_user.' + attr)
setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`'))
else:
# Technically, this can also use attrgetter
# However I'm not sure how I feel about "functions" returning properties
@ -180,12 +150,9 @@ def flatten_user(cls):
def generate_function(x):
# We want sphinx to properly show coroutine functions as coroutines
if inspect.iscoroutinefunction(value):
async def general(self, *args, **kwargs): # type: ignore
async def general(self, *args, **kwargs):
return await getattr(self._user, x)(*args, **kwargs)
else:
def general(self, *args, **kwargs):
return getattr(self._user, x)(*args, **kwargs)
@ -198,12 +165,10 @@ def flatten_user(cls):
return cls
M = TypeVar("M", bound="Member")
_BaseUser = discord.abc.User
@flatten_user
class Member(discord.abc.Messageable, _UserTag):
class Member(discord.abc.Messageable, _BaseUser):
"""Represents a Discord member to a :class:`Guild`.
This implements a lot of the functionality of :class:`User`.
@ -228,10 +193,6 @@ class Member(discord.abc.Messageable, _UserTag):
Returns the member's name with the discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
----------
joined_at: Optional[:class:`datetime.datetime`]
@ -256,103 +217,69 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.6
premium_since: Optional[:class:`datetime.datetime`]
An aware datetime object that specifies the date and time in UTC when the member used their
"Nitro boost" on the guild, if available. This could be ``None``.
Nitro boost on the guild, if available. This could be ``None``.
"""
__slots__ = (
"_roles",
"joined_at",
"premium_since",
"activities",
"guild",
"pending",
"nick",
"_client_status",
"_user",
"_state",
"_avatar",
)
__slots__ = ('_roles', 'joined_at', 'premium_since', '_client_status',
'activities', 'guild', 'pending', 'nick', '_user', '_state')
if TYPE_CHECKING:
name: str
id: int
discriminator: str
bot: bool
system: bool
created_at: datetime.datetime
default_avatar: Asset
avatar: Optional[Asset]
dm_channel: Optional[DMChannel]
create_dm = User.create_dm
mutual_guilds: List[Guild]
public_flags: PublicUserFlags
banner: Optional[Asset]
accent_color: Optional[Colour]
accent_colour: Optional[Colour]
def __init__(self, *, data, guild, state):
self._state = state
self._user = state.store_user(data['user'])
self.guild = guild
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
self._client_status = {
None: 'offline'
}
self.activities = []
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState):
self._state: ConnectionState = state
self._user: User = state.store_user(data["user"])
self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get("joined_at"))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get("premium_since"))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"]))
self._client_status: Dict[Optional[str], str] = {None: "offline"}
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get("nick", None)
self.pending: bool = data.get("pending", False)
self._avatar: Optional[str] = data.get("avatar")
def __str__(self) -> str:
def __str__(self):
return str(self._user)
def __int__(self) -> int:
return self.id
def __repr__(self):
return f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}' \
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
def __repr__(self) -> str:
return (
f"<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}"
f" bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>"
)
def __eq__(self, other):
return isinstance(other, _BaseUser) and other.id == self.id
def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return hash(self._user)
@classmethod
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M:
def _from_message(cls, *, message, data):
author = message.author
data["user"] = author._to_minimal_user_json() # type: ignore
return cls(data=data, guild=message.guild, state=message._state) # type: ignore
data['user'] = author._to_minimal_user_json()
return cls(data=data, guild=message.guild, state=message._state)
def _update_from_message(self, data: MemberPayload) -> None:
self.joined_at = utils.parse_time(data.get("joined_at"))
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self.nick = data.get("nick", None)
self.pending = data.get("pending", False)
def _update_from_message(self, data):
self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
self.nick = data.get('nick', None)
self.pending = data.get('pending', False)
@classmethod
def _try_upgrade(
cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState
) -> Union[User, M]:
def _try_upgrade(cls, *, data, guild, state):
# A User object with a 'member' key
try:
member_data = data.pop("member")
member_data = data.pop('member')
except KeyError:
return state.create_user(data)
return state.store_user(data)
else:
member_data["user"] = data # type: ignore
return cls(data=member_data, guild=guild, state=state) # type: ignore
member_data['user'] = data
return cls(data=member_data, guild=guild, state=state)
@classmethod
def _copy(cls: Type[M], member: M) -> M:
self: M = cls.__new__(cls) # to bypass __init__
def _copy(cls, member):
self = cls.__new__(cls) # to bypass __init__
self._roles = utils.SnowflakeList(member._roles, is_sorted=True)
self.joined_at = member.joined_at
@ -363,7 +290,6 @@ class Member(discord.abc.Messageable, _UserTag):
self.pending = member.pending
self.activities = member.activities
self._state = member._state
self._avatar = member._avatar
# Reference will not be copied unless necessary by PRESENCE_UPDATE
# See below
@ -374,39 +300,42 @@ class Member(discord.abc.Messageable, _UserTag):
ch = await self.create_dm()
return ch
def _update(self, data: MemberPayload) -> None:
def _update_roles(self, data):
self._roles = utils.SnowflakeList(map(int, data['roles']))
def _update(self, data):
# the nickname change is optional,
# if it isn't in the payload then it didn't change
try:
self.nick = data["nick"]
self.nick = data['nick']
except KeyError:
pass
try:
self.pending = data["pending"]
self.pending = data['pending']
except KeyError:
pass
self.premium_since = utils.parse_time(data.get("premium_since"))
self._roles = utils.SnowflakeList(map(int, data["roles"]))
self._avatar = data.get("avatar")
self.premium_since = utils.parse_time(data.get('premium_since'))
self._update_roles(data)
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]:
self.activities = tuple(map(create_activity, data["activities"]))
def _presence_update(self, data, user):
self.activities = tuple(map(create_activity, data['activities']))
self._client_status = {
sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore
sys.intern(key): sys.intern(value)
for key, value in data.get('client_status', {}).items()
}
self._client_status[None] = sys.intern(data["status"])
self._client_status[None] = sys.intern(data['status'])
if len(user) > 1:
return self._update_inner_user(user)
return None
return False
def _update_inner_user(self, user: UserPayload) -> Optional[Tuple[User, User]]:
def _update_inner_user(self, user):
u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags)
# These keys seem to always be available
modified = (user["username"], user["avatar"], user["discriminator"], user.get("public_flags", 0))
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
if original != modified:
to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified
@ -414,12 +343,12 @@ class Member(discord.abc.Messageable, _UserTag):
return to_return, u
@property
def status(self) -> Status:
def status(self):
""":class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead."""
return try_enum(Status, self._client_status[None])
@property
def raw_status(self) -> str:
def raw_status(self):
""":class:`str`: The member's overall status as a string value.
.. versionadded:: 1.5
@ -427,31 +356,31 @@ class Member(discord.abc.Messageable, _UserTag):
return self._client_status[None]
@status.setter
def status(self, value: Status) -> None:
def status(self, value):
# internal use only
self._client_status[None] = str(value)
@property
def mobile_status(self) -> Status:
def mobile_status(self):
""":class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get("mobile", "offline"))
return try_enum(Status, self._client_status.get('mobile', 'offline'))
@property
def desktop_status(self) -> Status:
def desktop_status(self):
""":class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get("desktop", "offline"))
return try_enum(Status, self._client_status.get('desktop', 'offline'))
@property
def web_status(self) -> Status:
def web_status(self):
""":class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get("web", "offline"))
return try_enum(Status, self._client_status.get('web', 'offline'))
def is_on_mobile(self) -> bool:
def is_on_mobile(self):
""":class:`bool`: A helper function that determines if a member is active on a mobile device."""
return "mobile" in self._client_status
return 'mobile' in self._client_status
@property
def colour(self) -> Colour:
def colour(self):
""":class:`Colour`: A property that returns a colour denoting the rendered colour
for the member. If the default colour is the one rendered then an instance
of :meth:`Colour.default` is returned.
@ -459,7 +388,7 @@ class Member(discord.abc.Messageable, _UserTag):
There is an alias for this named :attr:`color`.
"""
roles = self.roles[1:] # remove @everyone
roles = self.roles[1:] # remove @everyone
# highest order of the colour is the one that gets rendered.
# if the highest is the default colour then the next one with a colour
@ -470,7 +399,7 @@ class Member(discord.abc.Messageable, _UserTag):
return Colour.default()
@property
def color(self) -> Colour:
def color(self):
""":class:`Colour`: A property that returns a color denoting the rendered color for
the member. If the default color is the one rendered then an instance of :meth:`Colour.default`
is returned.
@ -480,7 +409,7 @@ class Member(discord.abc.Messageable, _UserTag):
return self.colour
@property
def roles(self) -> List[Role]:
def roles(self):
"""List[:class:`Role`]: A :class:`list` of :class:`Role` that the member belongs to. Note
that the first element of this list is always the default '@everyone'
role.
@ -498,14 +427,14 @@ class Member(discord.abc.Messageable, _UserTag):
return result
@property
def mention(self) -> str:
def mention(self):
""":class:`str`: Returns a string that allows you to mention the member."""
if self.nick:
return f"<@!{self._user.id}>"
return f"<@{self._user.id}>"
return f'<@!{self._user.id}>'
return f'<@{self._user.id}>'
@property
def display_name(self) -> str:
def display_name(self):
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
@ -515,31 +444,8 @@ class Member(discord.abc.Messageable, _UserTag):
return self.nick or self.name
@property
def display_avatar(self) -> Asset:
""":class:`Asset`: Returns the member's display avatar.
For regular members this is just their avatar, but
if they have a guild specific avatar then that
is returned instead.
.. versionadded:: 2.0
"""
return self.guild_avatar or self._user.avatar or self._user.default_avatar
@property
def guild_avatar(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild avatar
the member has. If unavailable, ``None`` is returned.
.. versionadded:: 2.0
"""
if self._avatar is None:
return None
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
@property
def activity(self) -> Optional[ActivityTypes]:
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
def activity(self):
"""Union[:class:`BaseActivity`, :class:`Spotify`]: Returns the primary
activity the user is currently doing. Could be ``None`` if no activity is being done.
.. note::
@ -555,7 +461,7 @@ class Member(discord.abc.Messageable, _UserTag):
if self.activities:
return self.activities[0]
def mentioned_in(self, message: Message) -> bool:
def mentioned_in(self, message):
"""Checks if the member is mentioned in the specified message.
Parameters
@ -577,7 +483,7 @@ class Member(discord.abc.Messageable, _UserTag):
return any(self._roles.has(role.id) for role in message.role_mentions)
@property
def top_role(self) -> Role:
def top_role(self):
""":class:`Role`: Returns the member's highest role.
This is useful for figuring where a member stands in the role
@ -590,7 +496,7 @@ class Member(discord.abc.Messageable, _UserTag):
return max(guild.get_role(rid) or guild.default_role for rid in self._roles)
@property
def guild_permissions(self) -> Permissions:
def guild_permissions(self):
""":class:`Permissions`: Returns the member's guild permissions.
This only takes into consideration the guild permissions
@ -615,21 +521,29 @@ class Member(discord.abc.Messageable, _UserTag):
return base
@property
def voice(self) -> Optional[VoiceState]:
def voice(self):
"""Optional[:class:`VoiceState`]: Returns the member's current voice state."""
return self.guild._voice_state_for(self._user.id)
@overload
async def ban(
self,
*,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1,
reason: Optional[str] = None,
reason: Optional[str] = ...,
delete_message_days: Literal[1, 2, 3, 4, 5, 6, 7] = ...,
) -> None:
...
@overload
async def ban(self) -> None:
...
async def ban(self, **kwargs):
"""|coro|
Bans this member. Equivalent to :meth:`Guild.ban`.
"""
await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days)
await self.guild.ban(self, **kwargs)
async def unban(self, *, reason: Optional[str] = None) -> None:
"""|coro|
@ -645,17 +559,25 @@ class Member(discord.abc.Messageable, _UserTag):
"""
await self.guild.kick(self, reason=reason)
@overload
async def edit(
self,
*,
nick: Optional[str] = MISSING,
mute: bool = MISSING,
deafen: bool = MISSING,
suppress: bool = MISSING,
roles: List[discord.abc.Snowflake] = MISSING,
voice_channel: Optional[VocalGuildChannel] = MISSING,
reason: Optional[str] = None,
) -> Optional[Member]:
reason: Optional[str] = ...,
nick: Optional[str] = None,
mute: bool = ...,
deafen: bool = ...,
suppress: bool = ...,
roles: Optional[List[discord.abc.Snowflake]] = ...,
voice_channel: Optional[VocalGuildChannel] = ...,
) -> None:
...
@overload
async def edit(self) -> None:
...
async def edit(self, *, reason=None, **fields):
"""|coro|
Edits the member's data.
@ -681,9 +603,6 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionchanged:: 1.1
Can now pass ``None`` to ``voice_channel`` to kick a member from voice.
.. versionchanged:: 2.0
The newly member is now optionally returned, if applicable.
Parameters
-----------
nick: Optional[:class:`str`]
@ -697,7 +616,7 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.7
roles: List[:class:`Role`]
roles: Optional[List[:class:`Role`]]
The member's new list of roles. This *replaces* the roles.
voice_channel: Optional[:class:`VoiceChannel`]
The voice channel to move the member to.
@ -711,58 +630,69 @@ class Member(discord.abc.Messageable, _UserTag):
You do not have the proper permissions to the action requested.
HTTPException
The operation failed.
Returns
--------
Optional[:class:`.Member`]
The newly updated member, if applicable. This is only returned
when certain fields are updated.
"""
http = self._state.http
guild_id = self.guild.id
me = self._state.self_id == self.id
payload: Dict[str, Any] = {}
payload = {}
if nick is not MISSING:
nick = nick or ""
try:
nick = fields['nick']
except KeyError:
# nick not present so...
pass
else:
nick = nick or ''
if me:
await http.change_my_nickname(guild_id, nick, reason=reason)
else:
payload["nick"] = nick
payload['nick'] = nick
if deafen is not MISSING:
payload["deaf"] = deafen
deafen = fields.get('deafen')
if deafen is not None:
payload['deaf'] = deafen
if mute is not MISSING:
payload["mute"] = mute
mute = fields.get('mute')
if mute is not None:
payload['mute'] = mute
if suppress is not MISSING:
suppress = fields.get('suppress')
if suppress is not None:
voice_state_payload = {
"channel_id": self.voice.channel.id,
"suppress": suppress,
'channel_id': self.voice.channel.id,
'suppress': suppress,
}
if suppress or self.bot:
voice_state_payload["request_to_speak_timestamp"] = None
voice_state_payload['request_to_speak_timestamp'] = None
if me:
await http.edit_my_voice_state(guild_id, voice_state_payload)
else:
if not suppress:
voice_state_payload["request_to_speak_timestamp"] = datetime.datetime.utcnow().isoformat()
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat()
await http.edit_voice_state(guild_id, self.id, voice_state_payload)
if voice_channel is not MISSING:
payload["channel_id"] = voice_channel and voice_channel.id
try:
vc = fields['voice_channel']
except KeyError:
pass
else:
payload['channel_id'] = vc and vc.id
if roles is not MISSING:
payload["roles"] = tuple(r.id for r in roles)
try:
roles = fields['roles']
except KeyError:
pass
else:
payload['roles'] = tuple(r.id for r in roles)
if payload:
data = await http.edit_member(guild_id, self.id, reason=reason, **payload)
return Member(data=data, guild=self.guild, state=self._state)
await http.edit_member(guild_id, self.id, reason=reason, **payload)
async def request_to_speak(self) -> None:
# TODO: wait for WS event for modify-in-place behaviour
async def request_to_speak(self):
"""|coro|
Request to speak in the connected channel.
@ -784,12 +714,12 @@ class Member(discord.abc.Messageable, _UserTag):
The operation failed.
"""
payload = {
"channel_id": self.voice.channel.id,
"request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(),
'channel_id': self.voice.channel.id,
'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(),
}
if self._state.self_id != self.id:
payload["suppress"] = False
payload['suppress'] = False
await self._state.http.edit_voice_state(self.guild.id, self.id, payload)
else:
await self._state.http.edit_my_voice_state(self.guild.id, payload)
@ -817,7 +747,7 @@ class Member(discord.abc.Messageable, _UserTag):
"""
await self.edit(voice_channel=channel, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None:
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True):
r"""|coro|
Gives the member a number of :class:`Role`\s.
@ -886,7 +816,7 @@ class Member(discord.abc.Messageable, _UserTag):
"""
if not atomic:
new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone
new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone
for role in roles:
try:
new_roles.remove(Object(id=role.id))
@ -901,7 +831,7 @@ class Member(discord.abc.Messageable, _UserTag):
for role in roles:
await req(guild_id, user_id, role.id, reason=reason)
def get_role(self, role_id: int, /) -> Optional[Role]:
def get_role(self, role_id: int) -> Optional[discord.Role]:
"""Returns a role with the given ID from roles which the member has.
.. versionadded:: 2.0

View File

@ -25,7 +25,9 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
__all__ = ("AllowedMentions",)
__all__ = (
'AllowedMentions',
)
if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload
@ -34,7 +36,7 @@ if TYPE_CHECKING:
class _FakeBool:
def __repr__(self):
return "True"
return 'True'
def __eq__(self, other):
return other is True
@ -45,7 +47,7 @@ class _FakeBool:
default: Any = _FakeBool()
A = TypeVar("A", bound="AllowedMentions")
A = TypeVar('A', bound='AllowedMentions')
class AllowedMentions:
@ -78,7 +80,7 @@ class AllowedMentions:
.. versionadded:: 1.6
"""
__slots__ = ("everyone", "users", "roles", "replied_user")
__slots__ = ('everyone', 'users', 'roles', 'replied_user')
def __init__(
self,
@ -114,22 +116,22 @@ class AllowedMentions:
data = {}
if self.everyone:
parse.append("everyone")
parse.append('everyone')
if self.users == True:
parse.append("users")
parse.append('users')
elif self.users != False:
data["users"] = [x.id for x in self.users]
data['users'] = [x.id for x in self.users]
if self.roles == True:
parse.append("roles")
parse.append('roles')
elif self.roles != False:
data["roles"] = [x.id for x in self.roles]
data['roles'] = [x.id for x in self.roles]
if self.replied_user:
data["replied_user"] = True
data['replied_user'] = True
data["parse"] = parse
data['parse'] = parse
return data # type: ignore
def merge(self, other: AllowedMentions) -> AllowedMentions:
@ -144,6 +146,6 @@ class AllowedMentions:
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(everyone={self.everyone}, "
f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})"
f'{self.__class__.__name__}(everyone={self.everyone}, '
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})'
)

File diff suppressed because it is too large Load Diff

View File

@ -22,31 +22,30 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypeVar
__all__ = (
"EqualityComparable",
"Hashable",
'EqualityComparable',
'Hashable',
)
E = TypeVar('E', bound='EqualityComparable')
class EqualityComparable:
__slots__ = ()
id: int
def __eq__(self, other: object) -> bool:
def __eq__(self: E, other: E) -> bool:
return isinstance(other, self.__class__) and other.id == self.id
def __ne__(self, other: object) -> bool:
def __ne__(self: E, other: E) -> bool:
if isinstance(other, self.__class__):
return other.id != self.id
return True
class Hashable(EqualityComparable):
__slots__ = ()
def __int__(self) -> int:
return self.id
def __hash__(self) -> int:
return self.id >> 22

View File

@ -35,11 +35,11 @@ from typing import (
if TYPE_CHECKING:
import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
__all__ = ("Object",)
__all__ = (
'Object',
)
class Object(Hashable):
"""Represents a generic Discord object.
@ -69,10 +69,6 @@ class Object(Hashable):
Returns the object's hash.
.. describe:: int(x)
Returns the object's ID.
Attributes
-----------
id: :class:`int`
@ -83,12 +79,12 @@ class Object(Hashable):
try:
id = int(id)
except ValueError:
raise TypeError(f"id parameter must be convertable to int not {id.__class__!r}") from None
raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None
else:
self.id = id
def __repr__(self) -> str:
return f"<Object id={self.id!r}>"
return f'<Object id={self.id!r}>'
@property
def created_at(self) -> datetime.datetime:

View File

@ -22,54 +22,40 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import struct
from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException
__all__ = (
"OggError",
"OggPage",
"OggStream",
'OggError',
'OggPage',
'OggStream',
)
class OggError(DiscordException):
"""An exception that is thrown for Ogg stream parsing errors."""
pass
# https://tools.ietf.org/html/rfc3533
# https://tools.ietf.org/html/rfc7845
class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct("<xBQIIIB")
if TYPE_CHECKING:
flag: int
gran_pos: int
serial: int
pagenum: int
crc: int
segnum: int
_header = struct.Struct('<xBQIIIB')
def __init__(self, stream: IO[bytes]) -> None:
def __init__(self, stream):
try:
header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.flag, self.gran_pos, self.serial, \
self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum)
bodylen = sum(struct.unpack("B" * self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen)
self.segtable = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
self.data = stream.read(bodylen)
except Exception:
raise OggError("bad data stream") from None
raise OggError('bad data stream') from None
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]:
def iter_packets(self):
packetlen = offset = 0
partial = True
@ -79,7 +65,7 @@ class OggPage:
partial = True
else:
packetlen += seg
yield self.data[offset : offset + packetlen], True
yield self.data[offset:offset+packetlen], True
offset += packetlen
packetlen = 0
partial = False
@ -87,31 +73,30 @@ class OggPage:
if partial:
yield self.data[offset:], False
class OggStream:
def __init__(self, stream: IO[bytes]) -> None:
self.stream: IO[bytes] = stream
def __init__(self, stream):
self.stream = stream
def _next_page(self) -> Optional[OggPage]:
def _next_page(self):
head = self.stream.read(4)
if head == b"OggS":
if head == b'OggS':
return OggPage(self.stream)
elif not head:
return None
else:
raise OggError("invalid header magic")
raise OggError('invalid header magic')
def _iter_pages(self) -> Generator[OggPage, None, None]:
def _iter_pages(self):
page = self._next_page()
while page:
yield page
page = self._next_page()
def iter_packets(self) -> Generator[bytes, None, None]:
partial = b""
def iter_packets(self):
partial = b''
for page in self._iter_pages():
for data, complete in page.iter_packets():
partial += data
if complete:
yield partial
partial = b""
partial = b''

View File

@ -22,10 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
import array
import ctypes
import ctypes.util
@ -35,157 +31,134 @@ import os.path
import struct
import sys
from .errors import DiscordException, InvalidArgument
if TYPE_CHECKING:
T = TypeVar("T")
BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"]
SIGNAL_CTL = Literal["auto", "voice", "music"]
class BandCtl(TypedDict):
narrow: int
medium: int
wide: int
superwide: int
full: int
class SignalCtl(TypedDict):
auto: int
voice: int
music: int
from .errors import DiscordException
__all__ = (
"Encoder",
"OpusError",
"OpusNotLoaded",
'Encoder',
'OpusError',
'OpusNotLoaded',
)
_log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
c_float_ptr = ctypes.POINTER(ctypes.c_float)
_lib = None
class EncoderStruct(ctypes.Structure):
pass
class DecoderStruct(ctypes.Structure):
pass
EncoderStructPtr = ctypes.POINTER(EncoderStruct)
DecoderStructPtr = ctypes.POINTER(DecoderStruct)
## Some constants from opus_defines.h
# Error codes
OK = 0
OK = 0
BAD_ARG = -1
# Encoder CTLs
APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048
APPLICATION_AUDIO = 2049
APPLICATION_VOIP = 2048
APPLICATION_LOWDELAY = 2051
CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008
CTL_SET_FEC = 4012
CTL_SET_PLP = 4014
CTL_SET_SIGNAL = 4024
CTL_SET_BITRATE = 4002
CTL_SET_BANDWIDTH = 4008
CTL_SET_FEC = 4012
CTL_SET_PLP = 4014
CTL_SET_SIGNAL = 4024
# Decoder CTLs
CTL_SET_GAIN = 4034
CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039
band_ctl: BandCtl = {
"narrow": 1101,
"medium": 1102,
"wide": 1103,
"superwide": 1104,
"full": 1105,
band_ctl = {
'narrow': 1101,
'medium': 1102,
'wide': 1103,
'superwide': 1104,
'full': 1105,
}
signal_ctl: SignalCtl = {
"auto": -1000,
"voice": 3001,
"music": 3002,
signal_ctl = {
'auto': -1000,
'voice': 3001,
'music': 3002,
}
def _err_lt(result: int, func: Callable, args: List) -> int:
def _err_lt(result, func, args):
if result < OK:
_log.info("error has happened in %s", func.__name__)
log.info('error has happened in %s', func.__name__)
raise OpusError(result)
return result
def _err_ne(result: T, func: Callable, args: List) -> T:
def _err_ne(result, func, args):
ret = args[-1]._obj
if ret.value != OK:
_log.info("error has happened in %s", func.__name__)
log.info('error has happened in %s', func.__name__)
raise OpusError(ret.value)
return result
# A list of exported functions.
# The first argument is obviously the name.
# The second one are the types of arguments it takes.
# The third is the result type.
# The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [
exported_functions = [
# Generic
("opus_get_version_string", None, ctypes.c_char_p, None),
("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None),
('opus_get_version_string',
None, ctypes.c_char_p, None),
('opus_strerror',
[ctypes.c_int], ctypes.c_char_p, None),
# Encoder functions
("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None),
("opus_encoder_create", [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
(
"opus_encode",
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
(
"opus_encode_float",
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32],
ctypes.c_int32,
_err_lt,
),
("opus_encoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_encoder_destroy", [EncoderStructPtr], None, None),
('opus_encoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_encoder_create',
[ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne),
('opus_encode',
[EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encode_float',
[EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt),
('opus_encoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_encoder_destroy',
[EncoderStructPtr], None, None),
# Decoder functions
("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None),
("opus_decoder_create", [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
(
"opus_decode",
('opus_decoder_get_size',
[ctypes.c_int], ctypes.c_int, None),
('opus_decoder_create',
[ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne),
('opus_decode',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int,
_err_lt,
),
(
"opus_decode_float",
ctypes.c_int, _err_lt),
('opus_decode_float',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int],
ctypes.c_int,
_err_lt,
),
("opus_decoder_ctl", None, ctypes.c_int32, _err_lt),
("opus_decoder_destroy", [DecoderStructPtr], None, None),
("opus_decoder_get_nb_samples", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
ctypes.c_int, _err_lt),
('opus_decoder_ctl',
None, ctypes.c_int32, _err_lt),
('opus_decoder_destroy',
[DecoderStructPtr], None, None),
('opus_decoder_get_nb_samples',
[DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt),
# Packet functions
("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt),
("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt),
("opus_packet_get_nb_frames", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
("opus_packet_get_samples_per_frame", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_bandwidth',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_channels',
[ctypes.c_char_p], ctypes.c_int, _err_lt),
('opus_packet_get_nb_frames',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
('opus_packet_get_samples_per_frame',
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
]
def libopus_loader(name: str) -> Any:
def libopus_loader(name):
# create the library...
lib = ctypes.cdll.LoadLibrary(name)
@ -205,29 +178,27 @@ def libopus_loader(name: str) -> Any:
if item[3]:
func.errcheck = item[3]
except KeyError:
_log.exception("Error assigning check function to %s", func)
log.exception("Error assigning check function to %s", func)
return lib
def _load_default() -> bool:
def _load_default():
global _lib
try:
if sys.platform == "win32":
if sys.platform == 'win32':
_basedir = os.path.dirname(os.path.abspath(__file__))
_bitness = struct.calcsize("P") * 8
_target = "x64" if _bitness > 32 else "x86"
_filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll")
_bitness = struct.calcsize('P') * 8
_target = 'x64' if _bitness > 32 else 'x86'
_filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll')
_lib = libopus_loader(_filename)
else:
_lib = libopus_loader(ctypes.util.find_library("opus"))
_lib = libopus_loader(ctypes.util.find_library('opus'))
except Exception:
_lib = None
return _lib is not None
def load_opus(name: str) -> None:
def load_opus(name):
"""Loads the libopus shared library for use with voice.
If this function is not called then the library uses the function
@ -265,8 +236,7 @@ def load_opus(name: str) -> None:
global _lib
_lib = libopus_loader(name)
def is_loaded() -> bool:
def is_loaded():
"""Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@ -280,7 +250,6 @@ def is_loaded() -> bool:
global _lib
return _lib is not None
class OpusError(DiscordException):
"""An exception that is thrown for libopus related errors.
@ -290,24 +259,21 @@ class OpusError(DiscordException):
The error code returned.
"""
def __init__(self, code: int):
self.code: int = code
msg = _lib.opus_strerror(self.code).decode("utf-8")
_log.info('"%s" has happened', msg)
def __init__(self, code):
self.code = code
msg = _lib.opus_strerror(self.code).decode('utf-8')
log.info('"%s" has happened', msg)
super().__init__(msg)
class OpusNotLoaded(DiscordException):
"""An exception that is thrown for when libopus is not loaded."""
pass
class _OpusStruct:
SAMPLING_RATE = 48000
CHANNELS = 2
FRAME_LENGTH = 20 # in milliseconds
SAMPLE_SIZE = struct.calcsize("h") * CHANNELS
SAMPLE_SIZE = struct.calcsize('h') * CHANNELS
SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH)
FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE
@ -317,101 +283,95 @@ class _OpusStruct:
if not is_loaded() and not _load_default():
raise OpusNotLoaded()
return _lib.opus_get_version_string().decode("utf-8")
return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO):
def __init__(self, application=APPLICATION_AUDIO):
_OpusStruct.get_opus_version()
self.application: int = application
self._state: EncoderStruct = self._create_state()
self.application = application
self._state = self._create_state()
self.set_bitrate(128)
self.set_fec(True)
self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth("full")
self.set_signal_type("auto")
self.set_bandwidth('full')
self.set_signal_type('auto')
def __del__(self) -> None:
if hasattr(self, "_state"):
def __del__(self):
if hasattr(self, '_state'):
_lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
self._state = None
def _create_state(self) -> EncoderStruct:
def _create_state(self):
ret = ctypes.c_int()
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
def set_bitrate(self, kbps: int) -> int:
def set_bitrate(self, kbps):
kbps = min(512, max(16, int(kbps)))
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
return kbps
def set_bandwidth(self, req: BAND_CTL) -> None:
def set_bandwidth(self, req):
if req not in band_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
k = band_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
def set_signal_type(self, req: SIGNAL_CTL) -> None:
def set_signal_type(self, req):
if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
def set_fec(self, enabled: bool = True) -> None:
def set_fec(self, enabled=True):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None:
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore
def set_expected_packet_loss_percent(self, percentage):
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes:
def encode(self, pcm, frame_size):
max_data_bytes = len(pcm)
# bytes can be used to reference pointer
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
pcm = ctypes.cast(pcm, c_int16_ptr)
data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know
return array.array("b", data[:ret]).tobytes() # type: ignore
ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes)
return array.array('b', data[:ret]).tobytes()
class Decoder(_OpusStruct):
def __init__(self):
_OpusStruct.get_opus_version()
self._state: DecoderStruct = self._create_state()
self._state = self._create_state()
def __del__(self) -> None:
if hasattr(self, "_state"):
def __del__(self):
if hasattr(self, '_state'):
_lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None
self._state = None # type: ignore
self._state = None
def _create_state(self) -> DecoderStruct:
def _create_state(self):
ret = ctypes.c_int()
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
@staticmethod
def packet_get_nb_frames(data: bytes) -> int:
def packet_get_nb_frames(data):
"""Gets the number of frames in an Opus packet"""
return _lib.opus_packet_get_nb_frames(data, len(data))
@staticmethod
def packet_get_nb_channels(data: bytes) -> int:
def packet_get_nb_channels(data):
"""Gets the number of channels in an Opus packet"""
return _lib.opus_packet_get_nb_channels(data)
@classmethod
def packet_get_samples_per_frame(cls, data: bytes) -> int:
def packet_get_samples_per_frame(cls, data):
"""Gets the number of samples per frame from an Opus packet"""
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
def _set_gain(self, adjustment: int) -> int:
def _set_gain(self, adjustment):
"""Configures decoder gain adjustment.
Scales the decoded output by a factor specified in Q8 dB units.
@ -423,34 +383,26 @@ class Decoder(_OpusStruct):
"""
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
def set_gain(self, dB: float) -> int:
def set_gain(self, dB):
"""Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8)
def set_volume(self, mult: float) -> int:
def set_volume(self, mult):
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self) -> int:
def _get_last_packet_duration(self):
"""Gets the duration (in samples) of the last packet successfully decoded or concealed."""
ret = ctypes.c_int32()
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
return ret.value
@overload
def decode(self, data: bytes, *, fec: bool) -> bytes:
...
@overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
...
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
def decode(self, data, *, fec=False):
if data is None and fec:
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data")
raise OpusError("Invalid arguments: FEC cannot be used with null data")
if data is None:
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME
@ -466,4 +418,4 @@ class Decoder(_OpusStruct):
ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec)
return array.array("h", pcm[: ret * channel_count]).tobytes()
return array.array('h', pcm[:ret * channel_count]).tobytes()

View File

@ -24,21 +24,22 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Dict, Optional, TYPE_CHECKING, Type, TypeVar, Union
from typing import Any, Dict, Optional, TYPE_CHECKING, Type, TypeVar
import re
from .asset import Asset, AssetMixin
from .errors import InvalidArgument
from . import utils
__all__ = ("PartialEmoji",)
__all__ = (
'PartialEmoji',
)
if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag:
__slots__ = ()
@ -48,7 +49,7 @@ class _EmojiTag:
raise NotImplementedError
PE = TypeVar("PE", bound="PartialEmoji")
PE = TypeVar('PE', bound='PartialEmoji')
class PartialEmoji(_EmojiTag, AssetMixin):
@ -89,9 +90,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
The ID of the custom emoji, if applicable.
"""
__slots__ = ("animated", "name", "id", "_state")
__slots__ = ('animated', 'name', 'id', '_state')
_CUSTOM_EMOJI_RE = re.compile(r"<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?")
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?')
if TYPE_CHECKING:
id: Optional[int]
@ -103,11 +104,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
self._state: Optional[ConnectionState] = None
@classmethod
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE:
def from_dict(cls: Type[PE], data: Dict[str, Any]) -> PE:
return cls(
animated=data.get("animated", False),
id=utils._get_as_snowflake(data, "id"),
name=data.get("name") or "",
animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'),
name=data.get('name', ''),
)
@classmethod
@ -138,19 +139,19 @@ class PartialEmoji(_EmojiTag, AssetMixin):
match = cls._CUSTOM_EMOJI_RE.match(value)
if match is not None:
groups = match.groupdict()
animated = bool(groups["animated"])
emoji_id = int(groups["id"])
name = groups["name"]
animated = bool(groups['animated'])
emoji_id = int(groups['id'])
name = groups['name']
return cls(name=name, animated=animated, id=emoji_id)
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {"name": self.name}
o: Dict[str, Any] = {'name': self.name}
if self.id:
o["id"] = self.id
o['id'] = self.id
if self.animated:
o["animated"] = self.animated
o['animated'] = self.animated
return o
def _to_partial(self) -> PartialEmoji:
@ -168,11 +169,11 @@ class PartialEmoji(_EmojiTag, AssetMixin):
if self.id is None:
return self.name
if self.animated:
return f"<a:{self.name}:{self.id}>"
return f"<:{self.name}:{self.id}>"
return f'<a:{self.name}:{self.id}>'
return f'<:{self.name}:{self.id}>'
def __repr__(self):
return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>"
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>'
def __eq__(self, other: Any) -> bool:
if self.is_unicode_emoji():
@ -199,7 +200,7 @@ class PartialEmoji(_EmojiTag, AssetMixin):
def _as_reaction(self) -> str:
if self.id is None:
return self.name
return f"{self.name}:{self.id}"
return f'{self.name}:{self.id}'
@property
def created_at(self) -> Optional[datetime]:
@ -219,13 +220,13 @@ class PartialEmoji(_EmojiTag, AssetMixin):
If this isn't a custom emoji then an empty string is returned
"""
if self.is_unicode_emoji():
return ""
return ''
fmt = "gif" if self.animated else "png"
return f"{Asset.BASE}/emojis/{self.id}.{fmt}"
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
async def read(self) -> bytes:
if self.is_unicode_emoji():
raise InvalidArgument("PartialEmoji is not a custom emoji")
raise InvalidArgument('PartialEmoji is not a custom emoji')
return await super().read()

View File

@ -22,34 +22,25 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, Tuple, Type, TypeVar, Optional
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
__all__ = (
"Permissions",
"PermissionOverwrite",
'Permissions',
'PermissionOverwrite',
)
# A permission alias works like a regular flag but is marked
# So the PermissionOverwrite knows to work with it
class permission_alias(alias_flag_value):
alias: str
pass
def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permission_alias]:
def decorator(func: Callable[[Any], int]) -> permission_alias:
def make_permission_alias(alias):
def decorator(func):
ret = permission_alias(func)
ret.alias = alias
return ret
return decorator
P = TypeVar("P", bound="Permissions")
@fill_with_flags()
class Permissions(BaseFlags):
"""Wraps up the Discord permission value.
@ -101,35 +92,35 @@ class Permissions(BaseFlags):
__slots__ = ()
def __init__(self, permissions: int = 0, **kwargs: bool):
def __init__(self, permissions=0, **kwargs):
if not isinstance(permissions, int):
raise TypeError(f"Expected int parameter, received {permissions.__class__.__name__} instead.")
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.')
self.value = permissions
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError(f"{key!r} is not a valid permission name.")
raise TypeError(f'{key!r} is not a valid permission name.')
setattr(self, key, value)
def is_subset(self, other: Permissions) -> bool:
def is_subset(self, other):
"""Returns ``True`` if self has the same or fewer permissions as other."""
if isinstance(other, Permissions):
return (self.value & other.value) == self.value
else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
def is_superset(self, other: Permissions) -> bool:
def is_superset(self, other):
"""Returns ``True`` if self has the same or more permissions as other."""
if isinstance(other, Permissions):
return (self.value | other.value) == self.value
else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
def is_strict_subset(self, other: Permissions) -> bool:
def is_strict_subset(self, other):
"""Returns ``True`` if the permissions on other are a strict subset of those on self."""
return self.is_subset(other) and self != other
def is_strict_superset(self, other: Permissions) -> bool:
def is_strict_superset(self, other):
"""Returns ``True`` if the permissions on other are a strict superset of those on self."""
return self.is_superset(other) and self != other
@ -139,20 +130,20 @@ class Permissions(BaseFlags):
__gt__ = is_strict_superset
@classmethod
def none(cls: Type[P]) -> P:
def none(cls):
"""A factory method that creates a :class:`Permissions` with all
permissions set to ``False``."""
return cls(0)
@classmethod
def all(cls: Type[P]) -> P:
def all(cls):
"""A factory method that creates a :class:`Permissions` with all
permissions set to ``True``.
"""
return cls(0b111111111111111111111111111111111111111)
return cls(0b111111111111111111111111111111111111)
@classmethod
def all_channel(cls: Type[P]) -> P:
def all_channel(cls):
"""A :class:`Permissions` with all channel-specific permissions set to
``True`` and the guild-specific ones set to ``False``. The guild-specific
permissions are currently:
@ -169,16 +160,11 @@ class Permissions(BaseFlags):
.. versionchanged:: 1.7
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
:attr:`request_to_speak` permissions.
"""
return cls(0b111110110110011111101111111111101010001)
return cls(0b10110011111101111111111101010001)
@classmethod
def general(cls: Type[P]) -> P:
def general(cls):
"""A factory method that creates a :class:`Permissions` with all
"General" permissions from the official Discord UI set to ``True``.
@ -191,7 +177,7 @@ class Permissions(BaseFlags):
return cls(0b01110000000010000000010010110000)
@classmethod
def membership(cls: Type[P]) -> P:
def membership(cls):
"""A factory method that creates a :class:`Permissions` with all
"Membership" permissions from the official Discord UI set to ``True``.
@ -200,28 +186,24 @@ class Permissions(BaseFlags):
return cls(0b00001100000000000000000000000111)
@classmethod
def text(cls: Type[P]) -> P:
def text(cls):
"""A factory method that creates a :class:`Permissions` with all
"Text" permissions from the official Discord UI set to ``True``.
.. versionchanged:: 1.7
Permission :attr:`read_messages` is no longer part of the text permissions.
Added :attr:`use_slash_commands` permission.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
"""
return cls(0b111110010000000000001111111100001000000)
return cls(0b10000000000001111111100001000000)
@classmethod
def voice(cls: Type[P]) -> P:
def voice(cls):
"""A factory method that creates a :class:`Permissions` with all
"Voice" permissions from the official Discord UI set to ``True``."""
return cls(0b00000011111100000000001100000000)
@classmethod
def stage(cls: Type[P]) -> P:
def stage(cls):
"""A factory method that creates a :class:`Permissions` with all
"Stage Channel" permissions from the official Discord UI set to ``True``.
@ -230,7 +212,7 @@ class Permissions(BaseFlags):
return cls(1 << 32)
@classmethod
def stage_moderator(cls: Type[P]) -> P:
def stage_moderator(cls):
"""A factory method that creates a :class:`Permissions` with all
"Stage Moderator" permissions from the official Discord UI set to ``True``.
@ -239,7 +221,7 @@ class Permissions(BaseFlags):
return cls(0b100000001010000000000000000000000)
@classmethod
def advanced(cls: Type[P]) -> P:
def advanced(cls):
"""A factory method that creates a :class:`Permissions` with all
"Advanced" permissions from the official Discord UI set to ``True``.
@ -247,7 +229,7 @@ class Permissions(BaseFlags):
"""
return cls(1 << 3)
def update(self, **kwargs: bool) -> None:
def update(self, **kwargs):
r"""Bulk updates this permission object.
Allows you to set multiple attributes by using keyword
@ -263,7 +245,7 @@ class Permissions(BaseFlags):
if key in self.VALID_FLAGS:
setattr(self, key, value)
def handle_overwrite(self, allow: int, deny: int) -> None:
def handle_overwrite(self, allow, deny):
# Basically this is what's happening here.
# We have an original bit array, e.g. 1010
# Then we have another bit array that is 'denied', e.g. 1111
@ -279,74 +261,67 @@ class Permissions(BaseFlags):
self.value = (self.value & ~deny) | allow
@flag_value
def create_instant_invite(self) -> int:
def create_instant_invite(self):
""":class:`bool`: Returns ``True`` if the user can create instant invites."""
return 1 << 0
@flag_value
def kick_members(self) -> int:
def kick_members(self):
""":class:`bool`: Returns ``True`` if the user can kick users from the guild."""
return 1 << 1
@flag_value
def ban_members(self) -> int:
def ban_members(self):
""":class:`bool`: Returns ``True`` if a user can ban users from the guild."""
return 1 << 2
@flag_value
def administrator(self) -> int:
def administrator(self):
""":class:`bool`: Returns ``True`` if a user is an administrator. This role overrides all other permissions.
This also bypasses all channel-specific overrides.
"""
return 1 << 3
@make_permission_alias("administrator")
def admin(self) -> int:
""":class:`bool`: An alias for :attr:`administrator`.
.. versionadded:: 2.0
"""
return 1 << 3
@flag_value
def manage_channels(self) -> int:
def manage_channels(self):
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.
This also corresponds to the "Manage Channel" channel-specific override."""
return 1 << 4
@flag_value
def manage_guild(self) -> int:
def manage_guild(self):
""":class:`bool`: Returns ``True`` if a user can edit guild properties."""
return 1 << 5
@flag_value
def add_reactions(self) -> int:
def add_reactions(self):
""":class:`bool`: Returns ``True`` if a user can add reactions to messages."""
return 1 << 6
@flag_value
def view_audit_log(self) -> int:
def view_audit_log(self):
""":class:`bool`: Returns ``True`` if a user can view the guild's audit log."""
return 1 << 7
@flag_value
def priority_speaker(self) -> int:
def priority_speaker(self):
""":class:`bool`: Returns ``True`` if a user can be more easily heard while talking."""
return 1 << 8
@flag_value
def stream(self) -> int:
def stream(self):
""":class:`bool`: Returns ``True`` if a user can stream in a voice channel."""
return 1 << 9
@flag_value
def read_messages(self) -> int:
def read_messages(self):
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
return 1 << 10
@make_permission_alias("read_messages")
def view_channel(self) -> int:
@make_permission_alias('read_messages')
def view_channel(self):
""":class:`bool`: An alias for :attr:`read_messages`.
.. versionadded:: 1.3
@ -354,17 +329,17 @@ class Permissions(BaseFlags):
return 1 << 10
@flag_value
def send_messages(self) -> int:
def send_messages(self):
""":class:`bool`: Returns ``True`` if a user can send messages from all or specific text channels."""
return 1 << 11
@flag_value
def send_tts_messages(self) -> int:
def send_tts_messages(self):
""":class:`bool`: Returns ``True`` if a user can send TTS messages from all or specific text channels."""
return 1 << 12
@flag_value
def manage_messages(self) -> int:
def manage_messages(self):
""":class:`bool`: Returns ``True`` if a user can delete or pin messages in a text channel.
.. note::
@ -374,32 +349,32 @@ class Permissions(BaseFlags):
return 1 << 13
@flag_value
def embed_links(self) -> int:
def embed_links(self):
""":class:`bool`: Returns ``True`` if a user's messages will automatically be embedded by Discord."""
return 1 << 14
@flag_value
def attach_files(self) -> int:
def attach_files(self):
""":class:`bool`: Returns ``True`` if a user can send files in their messages."""
return 1 << 15
@flag_value
def read_message_history(self) -> int:
def read_message_history(self):
""":class:`bool`: Returns ``True`` if a user can read a text channel's previous messages."""
return 1 << 16
@flag_value
def mention_everyone(self) -> int:
def mention_everyone(self):
""":class:`bool`: Returns ``True`` if a user's @everyone or @here will mention everyone in the text channel."""
return 1 << 17
@flag_value
def external_emojis(self) -> int:
def external_emojis(self):
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
return 1 << 18
@make_permission_alias("external_emojis")
def use_external_emojis(self) -> int:
@make_permission_alias('external_emojis')
def use_external_emojis(self):
""":class:`bool`: An alias for :attr:`external_emojis`.
.. versionadded:: 1.3
@ -407,7 +382,7 @@ class Permissions(BaseFlags):
return 1 << 18
@flag_value
def view_guild_insights(self) -> int:
def view_guild_insights(self):
""":class:`bool`: Returns ``True`` if a user can view the guild's insights.
.. versionadded:: 1.3
@ -415,55 +390,55 @@ class Permissions(BaseFlags):
return 1 << 19
@flag_value
def connect(self) -> int:
def connect(self):
""":class:`bool`: Returns ``True`` if a user can connect to a voice channel."""
return 1 << 20
@flag_value
def speak(self) -> int:
def speak(self):
""":class:`bool`: Returns ``True`` if a user can speak in a voice channel."""
return 1 << 21
@flag_value
def mute_members(self) -> int:
def mute_members(self):
""":class:`bool`: Returns ``True`` if a user can mute other users."""
return 1 << 22
@flag_value
def deafen_members(self) -> int:
def deafen_members(self):
""":class:`bool`: Returns ``True`` if a user can deafen other users."""
return 1 << 23
@flag_value
def move_members(self) -> int:
def move_members(self):
""":class:`bool`: Returns ``True`` if a user can move users between other voice channels."""
return 1 << 24
@flag_value
def use_voice_activation(self) -> int:
def use_voice_activation(self):
""":class:`bool`: Returns ``True`` if a user can use voice activation in voice channels."""
return 1 << 25
@flag_value
def change_nickname(self) -> int:
def change_nickname(self):
""":class:`bool`: Returns ``True`` if a user can change their nickname in the guild."""
return 1 << 26
@flag_value
def manage_nicknames(self) -> int:
def manage_nicknames(self):
""":class:`bool`: Returns ``True`` if a user can change other user's nickname in the guild."""
return 1 << 27
@flag_value
def manage_roles(self) -> int:
def manage_roles(self):
""":class:`bool`: Returns ``True`` if a user can create or edit roles less than their role's position.
This also corresponds to the "Manage Permissions" channel-specific override.
"""
return 1 << 28
@make_permission_alias("manage_roles")
def manage_permissions(self) -> int:
@make_permission_alias('manage_roles')
def manage_permissions(self):
""":class:`bool`: An alias for :attr:`manage_roles`.
.. versionadded:: 1.3
@ -471,25 +446,17 @@ class Permissions(BaseFlags):
return 1 << 28
@flag_value
def manage_webhooks(self) -> int:
def manage_webhooks(self):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete webhooks."""
return 1 << 29
@flag_value
def manage_emojis(self) -> int:
def manage_emojis(self):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30
@make_permission_alias("manage_emojis")
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
.. versionadded:: 2.0
"""
return 1 << 30
@flag_value
def use_slash_commands(self) -> int:
def use_slash_commands(self):
""":class:`bool`: Returns ``True`` if a user can use slash commands.
.. versionadded:: 1.7
@ -497,7 +464,7 @@ class Permissions(BaseFlags):
return 1 << 31
@flag_value
def request_to_speak(self) -> int:
def request_to_speak(self):
""":class:`bool`: Returns ``True`` if a user can request to speak in a stage channel.
.. versionadded:: 1.7
@ -505,7 +472,7 @@ class Permissions(BaseFlags):
return 1 << 32
@flag_value
def manage_events(self) -> int:
def manage_events(self):
""":class:`bool`: Returns ``True`` if a user can manage guild events.
.. versionadded:: 2.0
@ -513,7 +480,7 @@ class Permissions(BaseFlags):
return 1 << 33
@flag_value
def manage_threads(self) -> int:
def manage_threads(self):
""":class:`bool`: Returns ``True`` if a user can manage threads.
.. versionadded:: 2.0
@ -521,50 +488,23 @@ class Permissions(BaseFlags):
return 1 << 34
@flag_value
def create_public_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create public threads.
def use_threads(self):
""":class:`bool`: Returns ``True`` if a user can create and participate in public threads.
.. versionadded:: 2.0
"""
return 1 << 35
@flag_value
def create_private_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create private threads.
def use_private_threads(self):
""":class:`bool`: Returns ``True`` if a user can create and participate in private threads.
.. versionadded:: 2.0
"""
return 1 << 36
@flag_value
def external_stickers(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use stickers from other guilds.
.. versionadded:: 2.0
"""
return 1 << 37
@make_permission_alias("external_stickers")
def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`.
.. versionadded:: 2.0
"""
return 1 << 37
@flag_value
def send_messages_in_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can send messages in threads.
.. versionadded:: 2.0
"""
return 1 << 38
PO = TypeVar("PO", bound="PermissionOverwrite")
def _augment_from_permissions(cls):
def augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
aliases = set()
@ -581,7 +521,6 @@ def _augment_from_permissions(cls):
# god bless Python
def getter(self, x=key):
return self._values.get(x)
def setter(self, value, x=key):
self._set(x, value)
@ -591,8 +530,7 @@ def _augment_from_permissions(cls):
cls.PURE_FLAGS = cls.VALID_NAMES - aliases
return cls
@_augment_from_permissions
@augment_from_permissions
class PermissionOverwrite:
r"""A type that is used to represent a channel specific permission.
@ -625,79 +563,30 @@ class PermissionOverwrite:
Set the value of permissions by their name.
"""
__slots__ = ("_values",)
__slots__ = ('_values',)
if TYPE_CHECKING:
VALID_NAMES: ClassVar[Set[str]]
PURE_FLAGS: ClassVar[Set[str]]
# I wish I didn't have to do this
create_instant_invite: Optional[bool]
kick_members: Optional[bool]
ban_members: Optional[bool]
administrator: Optional[bool]
manage_channels: Optional[bool]
manage_guild: Optional[bool]
add_reactions: Optional[bool]
view_audit_log: Optional[bool]
priority_speaker: Optional[bool]
stream: Optional[bool]
read_messages: Optional[bool]
view_channel: Optional[bool]
send_messages: Optional[bool]
send_tts_messages: Optional[bool]
manage_messages: Optional[bool]
embed_links: Optional[bool]
attach_files: Optional[bool]
read_message_history: Optional[bool]
mention_everyone: Optional[bool]
external_emojis: Optional[bool]
use_external_emojis: Optional[bool]
view_guild_insights: Optional[bool]
connect: Optional[bool]
speak: Optional[bool]
mute_members: Optional[bool]
deafen_members: Optional[bool]
move_members: Optional[bool]
use_voice_activation: Optional[bool]
change_nickname: Optional[bool]
manage_nicknames: Optional[bool]
manage_roles: Optional[bool]
manage_permissions: Optional[bool]
manage_webhooks: Optional[bool]
manage_emojis: Optional[bool]
manage_emojis_and_stickers: Optional[bool]
use_slash_commands: Optional[bool]
request_to_speak: Optional[bool]
manage_events: Optional[bool]
manage_threads: Optional[bool]
create_public_threads: Optional[bool]
create_private_threads: Optional[bool]
send_messages_in_threads: Optional[bool]
external_stickers: Optional[bool]
use_external_stickers: Optional[bool]
def __init__(self, **kwargs: Optional[bool]):
self._values: Dict[str, Optional[bool]] = {}
def __init__(self, **kwargs):
self._values = {}
for key, value in kwargs.items():
if key not in self.VALID_NAMES:
raise ValueError(f"no permission called {key}.")
raise ValueError(f'no permission called {key}.')
setattr(self, key, value)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None:
def _set(self, key, value):
if value not in (True, None, False):
raise TypeError(f"Expected bool or NoneType, received {value.__class__.__name__}")
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}')
if value is None:
self._values.pop(key, None)
else:
self._values[key] = value
def pair(self) -> Tuple[Permissions, Permissions]:
def pair(self):
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
allow = Permissions.none()
@ -712,7 +601,7 @@ class PermissionOverwrite:
return allow, deny
@classmethod
def from_pair(cls: Type[PO], allow: Permissions, deny: Permissions) -> PO:
def from_pair(cls, allow, deny):
"""Creates an overwrite from an allow/deny pair of :class:`Permissions`."""
ret = cls()
for key, value in allow:
@ -725,7 +614,7 @@ class PermissionOverwrite:
return ret
def is_empty(self) -> bool:
def is_empty(self):
"""Checks if the permission overwrite is currently empty.
An empty permission overwrite is one that has no overwrites set
@ -738,7 +627,7 @@ class PermissionOverwrite:
"""
return len(self._values) == 0
def update(self, **kwargs: bool) -> None:
def update(self, **kwargs):
r"""Bulk updates this permission overwrite object.
Allows you to set multiple attributes by using keyword
@ -756,6 +645,6 @@ class PermissionOverwrite:
setattr(self, key, value)
def __iter__(self) -> Iterator[Tuple[str, Optional[bool]]]:
def __iter__(self):
for key in self.PURE_FLAGS:
yield key, self._values.get(key)

View File

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import threading
import traceback
@ -34,41 +33,27 @@ import time
import json
import sys
import re
import io
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .errors import ClientException
from .opus import Encoder as OpusEncoder
from .oggparse import OggStream
from .utils import MISSING
if TYPE_CHECKING:
from .voice_client import VoiceClient
AT = TypeVar("AT", bound="AudioSource")
FT = TypeVar("FT", bound="FFmpegOpusAudio")
_log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
__all__ = (
"AudioSource",
"PCMAudio",
"FFmpegAudio",
"FFmpegPCMAudio",
"FFmpegOpusAudio",
"PCMVolumeTransformer",
'AudioSource',
'PCMAudio',
'FFmpegAudio',
'FFmpegPCMAudio',
'FFmpegOpusAudio',
'PCMVolumeTransformer',
)
CREATE_NO_WINDOW: int
if sys.platform != "win32":
if sys.platform != 'win32':
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource:
"""Represents an audio stream.
@ -80,7 +65,7 @@ class AudioSource:
The audio source reads are done in a separate thread.
"""
def read(self) -> bytes:
def read(self):
"""Reads 20ms worth of audio.
Subclasses must implement this.
@ -100,11 +85,11 @@ class AudioSource:
"""
raise NotImplementedError
def is_opus(self) -> bool:
def is_opus(self):
"""Checks if the audio source is already encoded in Opus."""
return False
def cleanup(self) -> None:
def cleanup(self):
"""Called when clean-up is needed to be done.
Useful for clearing buffer data or processes after
@ -112,10 +97,9 @@ class AudioSource:
"""
pass
def __del__(self) -> None:
def __del__(self):
self.cleanup()
class PCMAudio(AudioSource):
"""Represents raw 16-bit 48KHz stereo PCM audio source.
@ -124,17 +108,15 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM.
"""
def __init__(self, stream):
self.stream = stream
def __init__(self, stream: io.BufferedIOBase) -> None:
self.stream: io.BufferedIOBase = stream
def read(self) -> bytes:
def read(self):
ret = self.stream.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b""
return b''
return ret
class FFmpegAudio(AudioSource):
"""Represents an FFmpeg (or AVConv) based AudioSource.
@ -144,78 +126,48 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3
"""
def __init__(
self, source: Union[str, io.BufferedIOBase], *, executable: str = "ffmpeg", args: Any, **subprocess_kwargs: Any
):
piping = subprocess_kwargs.get("stdin") == subprocess.PIPE
if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
def __init__(self, source, *, executable='ffmpeg', args, **subprocess_kwargs):
self._process = self._stdout = None
args = [executable, *args]
kwargs = {"stdout": subprocess.PIPE}
kwargs = {'stdout': subprocess.PIPE}
kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs)
self._stdout: IO[bytes] = self._process.stdout # type: ignore
self._stdin: Optional[IO[Bytes]] = None
self._pipe_thread: Optional[threading.Thread] = None
self._process = self._spawn_process(args, **kwargs)
self._stdout = self._process.stdout
if piping:
n = f"popen-stdin-writer:{id(self):#x}"
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start()
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
def _spawn_process(self, args, **subprocess_kwargs):
process = None
try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
except FileNotFoundError:
executable = args.partition(" ")[0] if isinstance(args, str) else args[0]
raise ClientException(executable + " was not found.") from None
executable = args.partition(' ')[0] if isinstance(args, str) else args[0]
raise ClientException(executable + ' was not found.') from None
except subprocess.SubprocessError as exc:
raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc
else:
return process
def _kill_process(self) -> None:
def cleanup(self):
proc = self._process
if proc is MISSING:
if proc is None:
return
_log.info("Preparing to terminate ffmpeg process %s.", proc.pid)
log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
try:
proc.kill()
except Exception:
_log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
if proc.poll() is None:
_log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid)
log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
proc.communicate()
_log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode)
log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
else:
_log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode)
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
# arbitrarily large read size
data = source.read(8192)
if not data:
self._process.terminate()
return
try:
self._stdin.write(data)
except Exception:
_log.debug("Write error for %s, this is probably not a problem", self, exc_info=True)
# at this point the source data is either exhausted or the process is fubar
self._process.terminate()
return
def cleanup(self) -> None:
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
self._process = self._stdout = None
class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@ -252,43 +204,33 @@ class FFmpegPCMAudio(FFmpegAudio):
The subprocess failed to be created.
"""
def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = "ffmpeg",
pipe: bool = False,
stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None,
) -> None:
def __init__(self, source, *, executable='ffmpeg', pipe=False, stderr=None, before_options=None, options=None):
args = []
subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append("-i")
args.append("-" if pipe else source)
args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning"))
args.append('-i')
args.append('-' if pipe else source)
args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning'))
if isinstance(options, str):
args.extend(shlex.split(options))
args.append("pipe:1")
args.append('pipe:1')
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
def read(self) -> bytes:
def read(self):
ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE:
return b""
return b''
return ret
def is_opus(self) -> bool:
def is_opus(self):
return False
class FFmpegOpusAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv).
@ -350,65 +292,38 @@ class FFmpegOpusAudio(FFmpegAudio):
The subprocess failed to be created.
"""
def __init__(
self,
source: Union[str, io.BufferedIOBase],
*,
bitrate: int = 128,
codec: Optional[str] = None,
executable: str = "ffmpeg",
pipe=False,
stderr=None,
before_options=None,
options=None,
) -> None:
def __init__(self, source, *, bitrate=128, codec=None, executable='ffmpeg',
pipe=False, stderr=None, before_options=None, options=None):
args = []
subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr}
subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str):
args.extend(shlex.split(before_options))
args.append("-i")
args.append("-" if pipe else source)
args.append('-i')
args.append('-' if pipe else source)
codec = "copy" if codec in ("opus", "libopus") else "libopus"
codec = 'copy' if codec in ('opus', 'libopus') else 'libopus'
args.extend(
(
"-map_metadata",
"-1",
"-f",
"opus",
"-c:a",
codec,
"-ar",
"48000",
"-ac",
"2",
"-b:a",
f"{bitrate}k",
"-loglevel",
"warning",
)
)
args.extend(('-map_metadata', '-1',
'-f', 'opus',
'-c:a', codec,
'-ar', '48000',
'-ac', '2',
'-b:a', f'{bitrate}k',
'-loglevel', 'warning'))
if isinstance(options, str):
args.extend(shlex.split(options))
args.append("pipe:1")
args.append('pipe:1')
super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
self._packet_iter = OggStream(self._stdout).iter_packets()
@classmethod
async def from_probe(
cls: Type[FT],
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
**kwargs: Any,
) -> FT:
async def from_probe(cls, source, *, method=None, **kwargs):
"""|coro|
A factory method that creates a :class:`FFmpegOpusAudio` after probing
@ -432,6 +347,7 @@ class FFmpegOpusAudio(FFmpegAudio):
def custom_probe(source, executable):
# some analysis code here
return codec, bitrate
source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe)
@ -464,18 +380,12 @@ class FFmpegOpusAudio(FFmpegAudio):
An instance of this class.
"""
executable = kwargs.get("executable")
executable = kwargs.get('executable')
codec, bitrate = await cls.probe(source, method=method, executable=executable)
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore
return cls(source, bitrate=bitrate, codec=codec, **kwargs)
@classmethod
async def probe(
cls,
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
executable: Optional[str] = None,
) -> Tuple[Optional[str], Optional[int]]:
async def probe(cls, source, *, method=None, executable=None):
"""|coro|
Probes the input source for bitrate and codec information.
@ -498,16 +408,16 @@ class FFmpegOpusAudio(FFmpegAudio):
Returns
---------
Optional[Tuple[Optional[:class:`str`], Optional[:class:`int`]]]
Tuple[Optional[:class:`str`], Optional[:class:`int`]]
A 2-tuple with the codec and bitrate of the input source.
"""
method = method or "native"
executable = executable or "ffmpeg"
method = method or 'native'
executable = executable or 'ffmpeg'
probefunc = fallback = None
if isinstance(method, str):
probefunc = getattr(cls, "_probe_codec_" + method, None)
probefunc = getattr(cls, '_probe_codec_' + method, None)
if probefunc is None:
raise AttributeError(f"Invalid probe method {method!r}")
@ -518,52 +428,53 @@ class FFmpegOpusAudio(FFmpegAudio):
probefunc = method
fallback = cls._probe_codec_fallback
else:
raise TypeError("Expected str or callable for parameter 'probe', " f"not '{method.__class__.__name__}'")
raise TypeError("Expected str or callable for parameter 'probe', " \
f"not '{method.__class__.__name__}'")
codec = bitrate = None
loop = asyncio.get_event_loop()
try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception:
if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable)
return # type: ignore
log.exception("Probe '%s' using '%s' failed", method, executable)
return
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception:
_log.exception("Fallback probe using '%s' failed", executable)
log.exception("Fallback probe using '%s' failed", executable)
else:
_log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
else:
_log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
finally:
return codec, bitrate
@staticmethod
def _probe_codec_native(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + "probe" if executable in ("ffmpeg", "avconv") else executable
args = [exe, "-v", "quiet", "-print_format", "json", "-show_streams", "-select_streams", "a:0", source]
def _probe_codec_native(source, executable='ffmpeg'):
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20)
codec = bitrate = None
if output:
data = json.loads(output)
streamdata = data["streams"][0]
streamdata = data['streams'][0]
codec = streamdata.get("codec_name")
bitrate = int(streamdata.get("bit_rate", 0))
bitrate = max(round(bitrate / 1000), 512)
codec = streamdata.get('codec_name')
bitrate = int(streamdata.get('bit_rate', 0))
bitrate = max(round(bitrate/1000, 0), 512)
return codec, bitrate
@staticmethod
def _probe_codec_fallback(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]:
args = [executable, "-hide_banner", "-i", source]
def _probe_codec_fallback(source, executable='ffmpeg'):
args = [executable, '-hide_banner', '-i', source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate(timeout=20)
output = out.decode("utf8")
output = out.decode('utf8')
codec = bitrate = None
codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output)
@ -576,14 +487,13 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate
def read(self) -> bytes:
return next(self._packet_iter, b"")
def read(self):
return next(self._packet_iter, b'')
def is_opus(self) -> bool:
def is_opus(self):
return True
class PCMVolumeTransformer(AudioSource, Generic[AT]):
class PCMVolumeTransformer(AudioSource):
"""Transforms a previous :class:`AudioSource` to have volume controls.
This does not work on audio sources that have :meth:`AudioSource.is_opus`
@ -605,54 +515,53 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
The audio source is opus encoded.
"""
def __init__(self, original: AT, volume: float = 1.0):
def __init__(self, original, volume=1.0):
if not isinstance(original, AudioSource):
raise TypeError(f"expected AudioSource not {original.__class__.__name__}.")
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.')
if original.is_opus():
raise ClientException("AudioSource must not be Opus encoded.")
raise ClientException('AudioSource must not be Opus encoded.')
self.original: AT = original
self.original = original
self.volume = volume
@property
def volume(self) -> float:
def volume(self):
"""Retrieves or sets the volume as a floating point percentage (e.g. ``1.0`` for 100%)."""
return self._volume
@volume.setter
def volume(self, value: float) -> None:
def volume(self, value):
self._volume = max(value, 0.0)
def cleanup(self) -> None:
def cleanup(self):
self.original.cleanup()
def read(self) -> bytes:
def read(self):
ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0
DELAY = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: VoiceClient, *, after=None):
def __init__(self, source, client, *, after=None):
threading.Thread.__init__(self)
self.daemon: bool = True
self.source: AudioSource = source
self.client: VoiceClient = client
self.after: Optional[Callable[[Optional[Exception]], Any]] = after
self.daemon = True
self.source = source
self.client = client
self.after = after
self._end: threading.Event = threading.Event()
self._resumed: threading.Event = threading.Event()
self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None
self._connected: threading.Event = client._connected
self._lock: threading.Lock = threading.Lock()
self._end = threading.Event()
self._resumed = threading.Event()
self._resumed.set() # we are not paused
self._current_error = None
self._connected = client._connected
self._lock = threading.Lock()
if after is not None and not callable(after):
raise TypeError('Expected a callable for the "after" parameter.')
def _do_run(self) -> None:
def _do_run(self):
self.loops = 0
self._start = time.perf_counter()
@ -687,7 +596,7 @@ class AudioPlayer(threading.Thread):
delay = max(0, self.DELAY + (next_time - time.perf_counter()))
time.sleep(delay)
def run(self) -> None:
def run(self):
try:
self._do_run()
except Exception as exc:
@ -697,53 +606,53 @@ class AudioPlayer(threading.Thread):
self.source.cleanup()
self._call_after()
def _call_after(self) -> None:
def _call_after(self):
error = self._current_error
if self.after is not None:
try:
self.after(error)
except Exception as exc:
_log.exception("Calling the after function failed.")
log.exception('Calling the after function failed.')
exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error:
msg = f"Exception in voice thread {self.name}"
_log.exception(msg, exc_info=error)
msg = f'Exception in voice thread {self.name}'
log.exception(msg, exc_info=error)
print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__)
def stop(self) -> None:
def stop(self):
self._end.set()
self._resumed.set()
self._speak(False)
def pause(self, *, update_speaking: bool = True) -> None:
def pause(self, *, update_speaking=True):
self._resumed.clear()
if update_speaking:
self._speak(False)
def resume(self, *, update_speaking: bool = True) -> None:
def resume(self, *, update_speaking=True):
self.loops = 0
self._start = time.perf_counter()
self._resumed.set()
if update_speaking:
self._speak(True)
def is_playing(self) -> bool:
def is_playing(self):
return self._resumed.is_set() and not self._end.is_set()
def is_paused(self) -> bool:
def is_paused(self):
return not self._end.is_set() and not self._resumed.is_set()
def _set_source(self, source: AudioSource) -> None:
def _set_source(self, source):
with self._lock:
self.pause(update_speaking=False)
self.source = source
self.resume(update_speaking=False)
def _speak(self, speaking: bool) -> None:
def _speak(self, speaking):
try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
except Exception as e:
_log.info("Speaking call in player failed: %s", e)
log.info("Speaking call in player failed: %s", e)

View File

View File

@ -22,44 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING:
from .types.raw_models import (
MessageDeleteEvent,
BulkMessageDeleteEvent,
ReactionActionEvent,
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent,
TypingEvent,
)
from .message import Message
from .partial_emoji import PartialEmoji
from .member import Member
__all__ = (
"RawMessageDeleteEvent",
"RawBulkMessageDeleteEvent",
"RawMessageUpdateEvent",
"RawReactionActionEvent",
"RawReactionClearEvent",
"RawReactionClearEmojiEvent",
"RawIntegrationDeleteEvent",
"RawTypingEvent",
'RawMessageDeleteEvent',
'RawBulkMessageDeleteEvent',
'RawMessageUpdateEvent',
'RawReactionActionEvent',
'RawReactionClearEvent',
'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent',
)
class _RawReprMixin:
def __repr__(self) -> str:
value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__)
return f"<{self.__class__.__name__} {value}>"
def __repr__(self):
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>'
class RawMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_message_delete` event.
@ -76,17 +52,16 @@ class RawMessageDeleteEvent(_RawReprMixin):
The cached message, if found in the internal message cache.
"""
__slots__ = ("message_id", "channel_id", "guild_id", "cached_message")
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
def __init__(self, data: MessageDeleteEvent) -> None:
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.cached_message: Optional[Message] = None
def __init__(self, data):
self.message_id = int(data['id'])
self.channel_id = int(data['channel_id'])
self.cached_message = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawBulkMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_bulk_message_delete` event.
@ -103,18 +78,17 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
The cached messages, if found in the internal message cache.
"""
__slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages")
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
def __init__(self, data: BulkMessageDeleteEvent) -> None:
self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])}
self.channel_id: int = int(data["channel_id"])
self.cached_messages: List[Message] = []
def __init__(self, data):
self.message_ids = {int(x) for x in data.get('ids', [])}
self.channel_id = int(data['channel_id'])
self.cached_messages = []
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawMessageUpdateEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_message_edit` event.
@ -139,19 +113,18 @@ class RawMessageUpdateEvent(_RawReprMixin):
it is modified by the data in :attr:`RawMessageUpdateEvent.data`.
"""
__slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message")
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
def __init__(self, data: MessageUpdateEvent) -> None:
self.message_id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.data: MessageUpdateEvent = data
self.cached_message: Optional[Message] = None
def __init__(self, data):
self.message_id = int(data['id'])
self.channel_id = int(data['channel_id'])
self.data = data
self.cached_message = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawReactionActionEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_add` or
@ -182,21 +155,21 @@ class RawReactionActionEvent(_RawReprMixin):
.. versionadded:: 1.3
"""
__slots__ = ("message_id", "user_id", "channel_id", "guild_id", "emoji", "event_type", "member")
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
'event_type', 'member')
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None:
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.emoji: PartialEmoji = emoji
self.event_type: str = event_type
self.member: Optional[Member] = None
def __init__(self, data, emoji, event_type):
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
self.user_id = int(data['user_id'])
self.emoji = emoji
self.event_type = event_type
self.member = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawReactionClearEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear` event.
@ -211,17 +184,16 @@ class RawReactionClearEvent(_RawReprMixin):
The guild ID where the reactions got cleared.
"""
__slots__ = ("message_id", "channel_id", "guild_id")
__slots__ = ('message_id', 'channel_id', 'guild_id')
def __init__(self, data: ReactionClearEvent) -> None:
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
def __init__(self, data):
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawReactionClearEmojiEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear_emoji` event.
@ -240,18 +212,17 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
The custom or unicode emoji being removed.
"""
__slots__ = ("message_id", "channel_id", "guild_id", "emoji")
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None:
self.emoji: PartialEmoji = emoji
self.message_id: int = int(data["message_id"])
self.channel_id: int = int(data["channel_id"])
def __init__(self, data, emoji):
self.emoji = emoji
self.message_id = int(data['message_id'])
self.channel_id = int(data['channel_id'])
try:
self.guild_id: Optional[int] = int(data["guild_id"])
self.guild_id = int(data['guild_id'])
except KeyError:
self.guild_id: Optional[int] = None
self.guild_id = None
class RawIntegrationDeleteEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_integration_delete` event.
@ -268,46 +239,13 @@ class RawIntegrationDeleteEvent(_RawReprMixin):
The guild ID where the integration got deleted.
"""
__slots__ = ("integration_id", "application_id", "guild_id")
__slots__ = ('integration_id', 'application_id', 'guild_id')
def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data["id"])
self.guild_id: int = int(data["guild_id"])
def __init__(self, data):
self.integration_id = int(data['id'])
self.guild_id = int(data['guild_id'])
try:
self.application_id: Optional[int] = int(data["application_id"])
self.application_id = int(data['application_id'])
except KeyError:
self.application_id: Optional[int] = None
class RawTypingEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_typing` event.
.. versionadded:: 2.0
Attributes
-----------
channel_id: :class:`int`
The channel ID where the typing originated from.
user_id: :class:`int`
The ID of the user that started typing.
when: :class:`datetime.datetime`
When the typing started as an aware datetime in UTC.
guild_id: Optional[:class:`int`]
The guild ID where the typing originated from, if applicable.
member: Optional[:class:`Member`]
The member who started typing. Only available if the member started typing in a guild.
"""
__slots__ = ("channel_id", "user_id", "when", "guild_id", "member")
def __init__(self, data: TypingEvent) -> None:
self.channel_id: int = int(data["channel_id"])
self.user_id: int = int(data["user_id"])
self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get("timestamp"), tz=datetime.timezone.utc)
self.member: Optional[Member] = None
try:
self.guild_id: Optional[int] = int(data["guild_id"])
except KeyError:
self.guild_id: Optional[int] = None
self.application_id = None

View File

@ -22,20 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Union, Optional
from .iterators import ReactionIterator
__all__ = ("Reaction",)
if TYPE_CHECKING:
from .types.message import Reaction as ReactionPayload
from .message import Message
from .partial_emoji import PartialEmoji
from .emoji import Emoji
from .abc import Snowflake
__all__ = (
'Reaction',
)
class Reaction:
"""Represents a reaction to a message.
@ -74,40 +65,37 @@ class Reaction:
message: :class:`Message`
Message this reaction is for.
"""
__slots__ = ('message', 'count', 'emoji', 'me')
__slots__ = ("message", "count", "emoji", "me")
def __init__(self, *, message, data, emoji=None):
self.message = message
self.emoji = emoji or message._state.get_reaction_emoji(data['emoji'])
self.count = data.get('count', 1)
self.me = data.get('me')
def __init__(
self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None
):
self.message: Message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data["emoji"])
self.count: int = data.get("count", 1)
self.me: bool = data.get("me")
# TODO: typeguard
def is_custom_emoji(self) -> bool:
@property
def custom_emoji(self):
""":class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
if isinstance(other, self.__class__):
return other.emoji != self.emoji
return True
def __hash__(self) -> int:
def __hash__(self):
return hash(self.emoji)
def __str__(self) -> str:
def __str__(self):
return str(self.emoji)
def __repr__(self) -> str:
return f"<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>"
def __repr__(self):
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>'
async def remove(self, user: Snowflake) -> None:
async def remove(self, user):
"""|coro|
Remove the reaction by the provided :class:`User` from the message.
@ -135,7 +123,7 @@ class Reaction:
await self.message.remove_reaction(self.emoji, user)
async def clear(self) -> None:
async def clear(self):
"""|coro|
Clears this reaction from the message.
@ -157,7 +145,7 @@ class Reaction:
"""
await self.message.clear_reaction(self.emoji)
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator:
def users(self, limit=None, after=None):
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
The ``after`` parameter must represent a member
@ -181,11 +169,11 @@ class Reaction:
Parameters
------------
limit: Optional[:class:`int`]
limit: :class:`int`
The maximum number of results to return.
If not provided, returns all the users who
reacted to the message.
after: Optional[:class:`abc.Snowflake`]
after: :class:`abc.Snowflake`
For pagination, reactions are sorted by member.
Raises
@ -202,8 +190,8 @@ class Reaction:
if the member has left the guild.
"""
if not isinstance(self.emoji, str):
emoji = f"{self.emoji.name}:{self.emoji.id}"
if self.custom_emoji:
emoji = f'{self.emoji.name}:{self.emoji.id}'
else:
emoji = self.emoji

View File

@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, TypeVar, Union, overload, TYPE_CHECKING
from typing import Any, List, Optional, TypeVar, Union, overload, TYPE_CHECKING
from .permissions import Permissions
from .errors import InvalidArgument
@ -32,8 +32,8 @@ from .mixins import Hashable
from .utils import snowflake_time, _get_as_snowflake, MISSING
__all__ = (
"RoleTags",
"Role",
'RoleTags',
'Role',
)
if TYPE_CHECKING:
@ -42,7 +42,6 @@ if TYPE_CHECKING:
Role as RolePayload,
RoleTags as RoleTagPayload,
)
from .types.guild import RolePositionUpdate
from .guild import Guild
from .member import Member
from .state import ConnectionState
@ -68,19 +67,19 @@ class RoleTags:
"""
__slots__ = (
"bot_id",
"integration_id",
"_premium_subscriber",
'bot_id',
'integration_id',
'_premium_subscriber',
)
def __init__(self, data: RoleTagPayload):
self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id")
self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id")
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id')
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id')
# NOTE: The API returns "null" for this if it's valid, which corresponds to None.
# This is different from other fields where "null" means "not there".
# So in this case, a value of None is the same as True.
# Which means we would need a different sentinel.
self._premium_subscriber: Optional[Any] = data.get("premium_subscriber", MISSING)
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING)
def is_bot_managed(self) -> bool:
""":class:`bool`: Whether the role is associated with a bot."""
@ -96,12 +95,12 @@ class RoleTags:
def __repr__(self) -> str:
return (
f"<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} "
f"premium_subscriber={self.is_premium_subscriber()}>"
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} '
f'premium_subscriber={self.is_premium_subscriber()}>'
)
R = TypeVar("R", bound="Role")
R = TypeVar('R', bound='Role')
class Role(Hashable):
@ -141,14 +140,6 @@ class Role(Hashable):
Returns the role's name.
.. describe:: str(x)
Returns the role's ID.
.. describe:: int(x)
Returns the role's ID.
Attributes
----------
id: :class:`int`
@ -181,40 +172,37 @@ class Role(Hashable):
"""
__slots__ = (
"id",
"name",
"_permissions",
"_colour",
"position",
"managed",
"mentionable",
"hoist",
"guild",
"tags",
"_state",
'id',
'name',
'_permissions',
'_colour',
'position',
'managed',
'mentionable',
'hoist',
'guild',
'tags',
'_state',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload):
self.guild: Guild = guild
self._state: ConnectionState = state
self.id: int = int(data["id"])
self.id: int = int(data['id'])
self._update(data)
def __str__(self) -> str:
return self.name
def __int__(self) -> int:
return self.id
def __repr__(self) -> str:
return f"<Role id={self.id} name={self.name!r}>"
return f'<Role id={self.id} name={self.name!r}>'
def __lt__(self: R, other: R) -> bool:
if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented
if self.guild != other.guild:
raise RuntimeError("cannot compare roles from two different guilds.")
raise RuntimeError('cannot compare roles from two different guilds.')
# the @everyone role is always the lowest role in hierarchy
guild_id = self.guild.id
@ -246,17 +234,17 @@ class Role(Hashable):
return not r
def _update(self, data: RolePayload):
self.name: str = data["name"]
self._permissions: int = int(data.get("permissions", 0))
self.position: int = data.get("position", 0)
self._colour: int = data.get("color", 0)
self.hoist: bool = data.get("hoist", False)
self.managed: bool = data.get("managed", False)
self.mentionable: bool = data.get("mentionable", False)
self.name: str = data['name']
self._permissions: int = int(data.get('permissions', 0))
self.position: int = data.get('position', 0)
self._colour: int = data.get('color', 0)
self.hoist: bool = data.get('hoist', False)
self.managed: bool = data.get('managed', False)
self.mentionable: bool = data.get('mentionable', False)
self.tags: Optional[RoleTags]
try:
self.tags = RoleTags(data["tags"])
self.tags = RoleTags(data['tags'])
except KeyError:
self.tags = None
@ -316,7 +304,7 @@ class Role(Hashable):
@property
def mention(self) -> str:
""":class:`str`: Returns a string that allows you to mention a role."""
return f"<@&{self.id}>"
return f'<@&{self.id}>'
@property
def members(self) -> List[Member]:
@ -348,21 +336,28 @@ class Role(Hashable):
else:
roles.append(self.id)
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
await http.move_role_position(self.guild.id, payload, reason=reason)
@overload
async def edit(
self,
*,
name: str = MISSING,
permissions: Permissions = MISSING,
colour: Union[Colour, int] = MISSING,
color: Union[Colour, int] = MISSING,
hoist: bool = MISSING,
mentionable: bool = MISSING,
position: int = MISSING,
reason: Optional[str] = MISSING,
) -> Optional[Role]:
reason: Optional[str] = ...,
name: str = ...,
permissions: Permissions = ...,
colour: Union[Colour, int] = ...,
hoist: bool = ...,
mentionable: bool = ...,
position: int = ...,
) -> None:
...
@overload
async def edit(self) -> None:
...
async def edit(self, *, reason=None, **fields) -> None:
"""|coro|
Edits the role.
@ -375,9 +370,6 @@ class Role(Hashable):
.. versionchanged:: 1.4
Can now pass ``int`` to ``colour`` keyword-only parameter.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited role is returned instead.
Parameters
-----------
name: :class:`str`
@ -405,39 +397,31 @@ class Role(Hashable):
InvalidArgument
An invalid position was given or the default
role was asked to be moved.
Returns
--------
:class:`Role`
The newly edited role.
"""
if position is not MISSING:
position = fields.get('position')
if position is not None:
await self._move(position, reason=reason)
self.position = position
payload: Dict[str, Any] = {}
if color is not MISSING:
colour = color
try:
colour = fields['colour']
except KeyError:
colour = fields.get('color', self.colour)
if colour is not MISSING:
if isinstance(colour, int):
payload["color"] = colour
else:
payload["color"] = colour.value
if isinstance(colour, int):
colour = Colour(value=colour)
if name is not MISSING:
payload["name"] = name
if permissions is not MISSING:
payload["permissions"] = permissions.value
if hoist is not MISSING:
payload["hoist"] = hoist
if mentionable is not MISSING:
payload["mentionable"] = mentionable
payload = {
'name': fields.get('name', self.name),
'permissions': str(fields.get('permissions', self.permissions).value),
'color': colour.value,
'hoist': fields.get('hoist', self.hoist),
'mentionable': fields.get('mentionable', self.mentionable),
}
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
return Role(guild=self.guild, data=data, state=self._state)
self._update(data)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|

View File

@ -22,9 +22,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import asyncio
import itertools
import logging
import aiohttp
@ -35,30 +34,22 @@ from .backoff import ExponentialBackoff
from .gateway import *
from .errors import (
ClientException,
InvalidArgument,
HTTPException,
GatewayNotFound,
ConnectionClosed,
PrivilegedIntentsRequired,
)
from . import utils
from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
if TYPE_CHECKING:
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
EI = TypeVar("EI", bound="EventItem")
__all__ = (
"AutoShardedClient",
"ShardInfo",
'AutoShardedClient',
'ShardInfo',
)
_log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
class EventType:
close = 0
@ -68,41 +59,39 @@ class EventType:
terminate = 4
clean_close = 5
class EventItem:
__slots__ = ("type", "shard", "error")
__slots__ = ('type', 'shard', 'error')
def __init__(self, etype: int, shard: Optional["Shard"], error: Optional[Exception]) -> None:
self.type: int = etype
self.shard: Optional["Shard"] = shard
self.error: Optional[Exception] = error
def __init__(self, etype, shard, error):
self.type = etype
self.shard = shard
self.error = error
def __lt__(self: EI, other: EI) -> bool:
def __lt__(self, other):
if not isinstance(other, EventItem):
return NotImplemented
return self.type < other.type
def __eq__(self: EI, other: EI) -> bool:
def __eq__(self, other):
if not isinstance(other, EventItem):
return NotImplemented
return self.type == other.type
def __hash__(self) -> int:
def __hash__(self):
return hash(self.type)
class Shard:
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None:
self.ws: DiscordWebSocket = ws
self._client: Client = client
self._dispatch: Callable[..., None] = client.dispatch
self._queue_put: Callable[[EventItem], None] = queue_put
self.loop: asyncio.AbstractEventLoop = self._client.loop
self._disconnect: bool = False
def __init__(self, ws, client, queue_put):
self.ws = ws
self._client = client
self._dispatch = client.dispatch
self._queue_put = queue_put
self.loop = self._client.loop
self._disconnect = False
self._reconnect = client._reconnect
self._backoff: ExponentialBackoff = ExponentialBackoff()
self._task: Optional[asyncio.Task] = None
self._handled_exceptions: Tuple[Type[Exception], ...] = (
self._backoff = ExponentialBackoff()
self._task = None
self._handled_exceptions = (
OSError,
HTTPException,
GatewayNotFound,
@ -112,28 +101,27 @@ class Shard:
)
@property
def id(self) -> int:
# DiscordWebSocket.shard_id is set in the from_client classmethod
return self.ws.shard_id # type: ignore
def id(self):
return self.ws.shard_id
def launch(self) -> None:
def launch(self):
self._task = self.loop.create_task(self.worker())
def _cancel_task(self) -> None:
def _cancel_task(self):
if self._task is not None and not self._task.done():
self._task.cancel()
async def close(self) -> None:
async def close(self):
self._cancel_task()
await self.ws.close(code=1000)
async def disconnect(self) -> None:
async def disconnect(self):
await self.close()
self._dispatch("shard_disconnect", self.id)
self._dispatch('shard_disconnect', self.id)
async def _handle_disconnect(self, e: Exception) -> None:
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
async def _handle_disconnect(self, e):
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
if not self._reconnect:
self._queue_put(EventItem(EventType.close, self, e))
return
@ -156,11 +144,11 @@ class Shard:
return
retry = self._backoff.delay()
_log.error("Attempting a reconnect for shard ID %s in %.2fs", self.id, retry, exc_info=e)
log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e))
async def worker(self) -> None:
async def worker(self):
while not self._client.is_closed():
try:
await self.ws.poll_event()
@ -177,19 +165,14 @@ class Shard:
self._queue_put(EventItem(EventType.terminate, self, e))
break
async def reidentify(self, exc: ReconnectWebSocket) -> None:
async def reidentify(self, exc):
self._cancel_task()
self._dispatch("disconnect")
self._dispatch("shard_disconnect", self.id)
_log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id)
self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id)
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
try:
coro = DiscordWebSocket.from_client(
self._client,
resume=exc.resume,
shard_id=self.id,
session=self.ws.session_id,
sequence=self.ws.sequence,
)
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
session=self.ws.session_id, sequence=self.ws.sequence)
self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e:
await self._handle_disconnect(e)
@ -200,7 +183,7 @@ class Shard:
else:
self.launch()
async def reconnect(self) -> None:
async def reconnect(self):
self._cancel_task()
try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
@ -214,7 +197,6 @@ class Shard:
else:
self.launch()
class ShardInfo:
"""A class that gives information and control over a specific shard.
@ -231,18 +213,18 @@ class ShardInfo:
The shard count for this cluster. If this is ``None`` then the bot has not started yet.
"""
__slots__ = ("_parent", "id", "shard_count")
__slots__ = ('_parent', 'id', 'shard_count')
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None:
self._parent: Shard = parent
self.id: int = parent.id
self.shard_count: Optional[int] = shard_count
def __init__(self, parent, shard_count):
self._parent = parent
self.id = parent.id
self.shard_count = shard_count
def is_closed(self) -> bool:
def is_closed(self):
""":class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open
async def disconnect(self) -> None:
async def disconnect(self):
"""|coro|
Disconnects a shard. When this is called, the shard connection will no
@ -255,7 +237,7 @@ class ShardInfo:
await self._parent.disconnect()
async def reconnect(self) -> None:
async def reconnect(self):
"""|coro|
Disconnects and then connects the shard again.
@ -264,7 +246,7 @@ class ShardInfo:
await self._parent.disconnect()
await self._parent.reconnect()
async def connect(self) -> None:
async def connect(self):
"""|coro|
Connects a shard. If the shard is already connected this does nothing.
@ -275,11 +257,11 @@ class ShardInfo:
await self._parent.reconnect()
@property
def latency(self) -> float:
def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency
def is_ws_ratelimited(self) -> bool:
def is_ws_ratelimited(self):
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members
@ -289,7 +271,6 @@ class ShardInfo:
"""
return self._parent.ws.is_ratelimited()
class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
@ -316,20 +297,16 @@ class AutoShardedClient(Client):
shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with.
"""
if TYPE_CHECKING:
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop("shard_id", None)
self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None)
def __init__(self, *args, loop=None, **kwargs):
kwargs.pop('shard_id', None)
self.shard_ids = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None:
if self.shard_count is None:
raise ClientException("When passing manual shard_ids, you must provide a shard_count.")
raise ClientException('When passing manual shard_ids, you must provide a shard_count.')
elif not isinstance(self.shard_ids, (list, tuple)):
raise ClientException("shard_ids parameter must be a list or a tuple.")
raise ClientException('shard_ids parameter must be a list or a tuple.')
# instead of a single websocket, we have multiple
# the key is the shard_id
@ -338,24 +315,18 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket:
def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None:
# guild_id won't be None if shard_id is None and shard_count won't be None here
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
shard_id = (guild_id >> 22) % self.shard_count
return self.__shards[shard_id].ws
def _get_state(self, **options: Any) -> AutoShardedConnectionState:
return AutoShardedConnectionState(
dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
def _get_state(self, **options):
return AutoShardedConnectionState(dispatch=self.dispatch,
handlers=self._handlers,
hooks=self._hooks, http=self.http, loop=self.loop, **options)
@property
def latency(self) -> float:
def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This operates similarly to :meth:`Client.latency` except it uses the average
@ -363,18 +334,18 @@ class AutoShardedClient(Client):
:attr:`latencies` property. Returns ``nan`` if there are no shards ready.
"""
if not self.__shards:
return float("nan")
return float('nan')
return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property
def latencies(self) -> List[Tuple[int, float]]:
def latencies(self):
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``.
"""
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
def get_shard(self, shard_id: int) -> Optional[ShardInfo]:
def get_shard(self, shard_id):
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try:
parent = self.__shards[shard_id]
@ -384,16 +355,16 @@ class AutoShardedClient(Client):
return ShardInfo(parent, self.shard_count)
@property
def shards(self) -> Dict[int, ShardInfo]:
def shards(self):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None:
async def launch_shard(self, gateway, shard_id, *, initial=False):
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception:
_log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id)
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id)
@ -401,7 +372,7 @@ class AutoShardedClient(Client):
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
ret.launch()
async def launch_shards(self) -> None:
async def launch_shards(self):
if self.shard_count is None:
self.shard_count, gateway = await self.http.get_bot_gateway()
else:
@ -418,7 +389,7 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set()
async def connect(self, *, reconnect: bool = True) -> None:
async def connect(self, *, reconnect=True):
self._reconnect = reconnect
await self.launch_shards()
@ -442,7 +413,7 @@ class AutoShardedClient(Client):
elif item.type == EventType.clean_close:
return
async def close(self) -> None:
async def close(self):
"""|coro|
Closes the connection to Discord.
@ -454,7 +425,7 @@ class AutoShardedClient(Client):
for vc in self.voice_clients:
try:
await vc.disconnect(force=True)
await vc.disconnect()
except Exception:
pass
@ -465,13 +436,7 @@ class AutoShardedClient(Client):
await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
async def change_presence(
self,
*,
activity: Optional[BaseActivity] = None,
status: Optional[Status] = None,
shard_id: int = None,
) -> None:
async def change_presence(self, *, activity=None, status=None, afk=False, shard_id=None):
"""|coro|
Changes the client's presence.
@ -481,9 +446,6 @@ class AutoShardedClient(Client):
game = discord.Game("with the API")
await client.change_presence(status=discord.Status.idle, activity=game)
.. versionchanged:: 2.0
Removed the ``afk`` keyword-only parameter.
Parameters
----------
activity: Optional[:class:`BaseActivity`]
@ -491,6 +453,10 @@ class AutoShardedClient(Client):
status: Optional[:class:`Status`]
Indicates what status to change to. If ``None``, then
:attr:`Status.online` is used.
afk: :class:`bool`
Indicates if you are going AFK. This allows the discord
client to know how to handle push notifications better
for you in case you are actually idle and not lying.
shard_id: Optional[:class:`int`]
The shard_id to change the presence to. If not specified
or ``None``, then it will change the presence of every
@ -503,23 +469,23 @@ class AutoShardedClient(Client):
"""
if status is None:
status_value = "online"
status = 'online'
status_enum = Status.online
elif status is Status.offline:
status_value = "invisible"
status = 'invisible'
status_enum = Status.offline
else:
status_enum = status
status_value = str(status)
status = str(status)
if shard_id is None:
for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status_value)
await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = self._connection.guilds
else:
shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status_value)
await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
activities = () if activity is None else (activity,)
@ -528,11 +494,10 @@ class AutoShardedClient(Client):
if me is None:
continue
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...]
me.activities = activities # type: ignore
me.activities = activities
me.status = status_enum
def is_ws_ratelimited(self) -> bool:
def is_ws_ratelimited(self):
""":class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members

View File

@ -31,7 +31,9 @@ from .mixins import Hashable
from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum
__all__ = ("StageInstance",)
__all__ = (
'StageInstance',
)
if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload
@ -49,7 +51,7 @@ class StageInstance(Hashable):
.. describe:: x == y
Checks if two stage instances are equal.
Checks if two stagea instances are equal.
.. describe:: x != y
@ -59,10 +61,6 @@ class StageInstance(Hashable):
Returns the stage instance's hash.
.. describe:: int(x)
Returns the stage instance's ID.
Attributes
-----------
id: :class:`int`
@ -76,18 +74,18 @@ class StageInstance(Hashable):
privacy_level: :class:`StagePrivacyLevel`
The privacy level of the stage instance.
discoverable_disabled: :class:`bool`
Whether discoverability for the stage instance is disabled.
Whether the stage instance is discoverable.
"""
__slots__ = (
"_state",
"id",
"guild",
"channel_id",
"topic",
"privacy_level",
"discoverable_disabled",
"_cs_channel",
'_state',
'id',
'guild',
'channel_id',
'topic',
'privacy_level',
'discoverable_disabled',
'_cs_channel',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
@ -96,27 +94,24 @@ class StageInstance(Hashable):
self._update(data)
def _update(self, data: StageInstancePayload):
self.id: int = int(data["id"])
self.channel_id: int = int(data["channel_id"])
self.topic: str = data["topic"]
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data["privacy_level"])
self.discoverable_disabled: bool = data.get("discoverable_disabled", False)
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled = data['discoverable_disabled']
def __repr__(self) -> str:
return f"<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>"
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
@cached_slot_property("_cs_channel")
@cached_slot_property('_cs_channel')
def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore
"""Optional[:class:`StageChannel`: The guild that stage instance is running in."""
return self._state.get_channel(self.channel_id)
def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public
async def edit(
self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None
) -> None:
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING) -> None:
"""|coro|
Edits the stage instance.
@ -130,8 +125,6 @@ class StageInstance(Hashable):
The stage instance's new topic.
privacy_level: :class:`StagePrivacyLevel`
The stage instance's new privacy level.
reason: :class:`str`
The reason the stage instance was edited. Shows up on the audit log.
Raises
------
@ -146,18 +139,18 @@ class StageInstance(Hashable):
payload = {}
if topic is not MISSING:
payload["topic"] = topic
payload['topic'] = topic
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument("privacy_level field must be of type PrivacyLevel")
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
payload["privacy_level"] = privacy_level.value
payload['privacy_level'] = privacy_level.value
if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
await self._state.http.edit_stage_instance(self.channel_id, **payload)
async def delete(self, *, reason: Optional[str] = None) -> None:
async def delete(self) -> None:
"""|coro|
Deletes the stage instance.
@ -165,11 +158,6 @@ class StageInstance(Hashable):
You must have the :attr:`~Permissions.manage_channels` permission to
use this.
Parameters
-----------
reason: :class:`str`
The reason the stage instance was deleted. Shows up on the audit log.
Raises
------
Forbidden
@ -177,4 +165,4 @@ class StageInstance(Hashable):
HTTPException
Deleting the stage instance failed.
"""
await self._state.http.delete_stage_instance(self.channel_id, reason=reason)
await self._state.http.delete_stage_instance(self.channel_id)

File diff suppressed because it is too large Load Diff

View File

@ -23,226 +23,24 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Literal, TYPE_CHECKING, List, Optional, Tuple, Type, Union
import unicodedata
from typing import TYPE_CHECKING, List, Optional
from .mixins import Hashable
from .asset import Asset, AssetMixin
from .utils import cached_slot_property, find, snowflake_time, get, MISSING
from .errors import InvalidData
from .enums import StickerType, StickerFormatType, try_enum
from .asset import Asset
from .utils import snowflake_time
from .enums import StickerType, try_enum
__all__ = (
"StickerPack",
"StickerItem",
"Sticker",
"StandardSticker",
"GuildSticker",
'Sticker',
)
if TYPE_CHECKING:
import datetime
from .state import ConnectionState
from .user import User
from .guild import Guild
from .types.sticker import (
StickerPack as StickerPackPayload,
StickerItem as StickerItemPayload,
Sticker as StickerPayload,
StandardSticker as StandardStickerPayload,
GuildSticker as GuildStickerPayload,
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
EditGuildSticker,
)
from .types.message import Sticker as StickerPayload
class StickerPack(Hashable):
"""Represents a sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker pack.
.. describe:: hash(x)
Returns the hash of the sticker pack.
.. describe:: int(x)
Returns the ID of the sticker pack.
.. describe:: x == y
Checks if the sticker pack is equal to another sticker pack.
.. describe:: x != y
Checks if the sticker pack is not equal to another sticker pack.
Attributes
-----------
name: :class:`str`
The name of the sticker pack.
description: :class:`str`
The description of the sticker pack.
id: :class:`int`
The id of the sticker pack.
stickers: List[:class:`StandardSticker`]
The stickers of this sticker pack.
sku_id: :class:`int`
The SKU ID of the sticker pack.
cover_sticker_id: :class:`int`
The ID of the sticker used for the cover of the sticker pack.
cover_sticker: :class:`StandardSticker`
The sticker used for the cover of the sticker pack.
"""
__slots__ = (
"_state",
"id",
"stickers",
"name",
"sku_id",
"cover_sticker_id",
"cover_sticker",
"description",
"_banner",
)
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data["id"])
stickers = data["stickers"]
self.stickers: List[StandardSticker] = [
StandardSticker(state=self._state, data=sticker) for sticker in stickers
]
self.name: str = data["name"]
self.sku_id: int = int(data["sku_id"])
self.cover_sticker_id: int = int(data["cover_sticker_id"])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data["description"]
self._banner: int = int(data["banner_asset_id"])
@property
def banner(self) -> Asset:
""":class:`Asset`: The banner asset of the sticker pack."""
return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str:
return f"<StickerPack id={self.id} name={self.name!r} description={self.description!r}>"
def __str__(self) -> str:
return self.name
class _StickerTag(Hashable, AssetMixin):
__slots__ = ()
id: int
format: StickerFormatType
async def read(self) -> bytes:
"""|coro|
Retrieves the content of this sticker as a :class:`bytes` object.
.. note::
Stickers that use the :attr:`StickerFormatType.lottie` format cannot be read.
Raises
------
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
TypeError
The sticker is a lottie type.
Returns
-------
:class:`bytes`
The content of the asset.
"""
if self.format is StickerFormatType.lottie:
raise TypeError('Cannot read stickers of format "lottie".')
return await super().read()
class StickerItem(_StickerTag):
"""Represents a sticker item.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker item.
.. describe:: x == y
Checks if the sticker item is equal to another sticker item.
.. describe:: x != y
Checks if the sticker item is not equal to another sticker item.
Attributes
-----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
url: :class:`str`
The URL for the sticker's image.
"""
__slots__ = ("_state", "name", "id", "format", "url")
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state
self.name: str = data["name"]
self.id: int = int(data["id"])
self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
def __repr__(self) -> str:
return f"<StickerItem id={self.id} name={self.name!r} format={self.format}>"
def __str__(self) -> str:
return self.name
async def fetch(self) -> Union[Sticker, StandardSticker, GuildSticker]:
"""|coro|
Attempts to retrieve the full sticker data of the sticker item.
Raises
--------
HTTPException
Retrieving the sticker failed.
Returns
--------
Union[:class:`StandardSticker`, :class:`GuildSticker`]
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._state, data=data)
class Sticker(_StickerTag):
class Sticker(Hashable):
"""Represents a sticker.
.. versionadded:: 1.6
@ -271,27 +69,30 @@ class Sticker(_StickerTag):
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerFormatType`
format: :class:`StickerType`
The format for the sticker's image.
url: :class:`str`
The URL for the sticker's image.
tags: List[:class:`str`]
A list of tags for the sticker.
"""
__slots__ = ("_state", "id", "name", "description", "format", "url")
__slots__ = ('_state', 'id', 'name', 'description', 'pack_id', 'format', '_image', 'tags')
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
def __init__(self, *, state: ConnectionState, data: StickerPayload):
self._state: ConnectionState = state
self._from_data(data)
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self.pack_id: int = int(data['pack_id'])
self.format: StickerType = try_enum(StickerType, data['format_type'])
self._image: str = data['asset']
def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data["id"])
self.name: str = data["name"]
self.description: str = data["description"]
self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"])
self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}"
try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')]
except KeyError:
self.tags = []
def __repr__(self) -> str:
return f"<Sticker id={self.id} name={self.name!r}>"
return f'<{self.__class__.__name__} id={self.id} name={self.name!r}>'
def __str__(self) -> str:
return self.name
@ -301,235 +102,19 @@ class Sticker(_StickerTag):
""":class:`datetime.datetime`: Returns the sticker's creation time in UTC."""
return snowflake_time(self.id)
@property
def image(self) -> Optional[Asset]:
"""Returns an :class:`Asset` for the sticker's image.
class StandardSticker(Sticker):
"""Represents a sticker that is found in a standard sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerFormatType`
The format for the sticker's image.
tags: List[:class:`str`]
A list of tags for the sticker.
sort_value: :class:`int`
The sticker's sort order within its pack.
"""
__slots__ = ("sort_value", "pack_id", "type", "tags")
def _from_data(self, data: StandardStickerPayload) -> None:
super()._from_data(data)
self.sort_value: int = data["sort_value"]
self.pack_id: int = int(data["pack_id"])
self.type: StickerType = StickerType.standard
try:
self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")]
except KeyError:
self.tags = []
def __repr__(self) -> str:
return f"<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>"
async def pack(self) -> StickerPack:
"""|coro|
Retrieves the sticker pack that this sticker belongs to.
Raises
--------
InvalidData
The corresponding sticker pack was not found.
HTTPException
Retrieving the sticker pack failed.
.. note::
This will return ``None`` if the format is ``StickerType.lottie``.
Returns
--------
:class:`StickerPack`
The retrieved sticker pack.
"""
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs()
packs = data["sticker_packs"]
pack = find(lambda d: int(d["id"]) == self.pack_id, packs)
if pack:
return StickerPack(state=self._state, data=pack)
raise InvalidData(f"Could not find corresponding sticker pack for {self!r}")
class GuildSticker(Sticker):
"""Represents a sticker that belongs to a guild.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
available: :class:`bool`
Whether this sticker is available for use.
guild_id: :class:`int`
The ID of the guild that this sticker is from.
user: Optional[:class:`User`]
The user that created this sticker. This can only be retrieved using :meth:`Guild.fetch_sticker` and
having the :attr:`~Permissions.manage_emojis_and_stickers` permission.
emoji: :class:`str`
The name of a unicode emoji that represents this sticker.
"""
__slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild")
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data["available"]
self.guild_id: int = int(data["guild_id"])
user = data.get("user")
self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data["tags"]
self.type: StickerType = StickerType.guild
def __repr__(self) -> str:
return f"<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>"
@cached_slot_property("_cs_guild")
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild.
.. versionadded:: 2.0
"""
return self._state._get_guild(self.guild_id)
async def edit(
self,
*,
name: str = MISSING,
description: str = MISSING,
emoji: str = MISSING,
reason: Optional[str] = None,
) -> GuildSticker:
"""|coro|
Edits a :class:`GuildSticker` for the guild.
Parameters
-----------
name: :class:`str`
The sticker's new name. Must be at least 2 characters.
description: Optional[:class:`str`]
The sticker's new description. Can be ``None``.
emoji: :class:`str`
The name of a unicode emoji that represents the sticker's expression.
reason: :class:`str`
The reason for editing this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to edit stickers.
HTTPException
An error occurred editing the sticker.
Returns
--------
:class:`GuildSticker`
The newly modified sticker.
Optional[:class:`Asset`]
The resulting CDN asset.
"""
payload: EditGuildSticker = {}
if self.format is StickerType.lottie:
return None
if name is not MISSING:
payload["name"] = name
if description is not MISSING:
payload["description"] = description
if emoji is not MISSING:
try:
emoji = unicodedata.name(emoji)
except TypeError:
pass
else:
emoji = emoji.replace(" ", "_")
payload["tags"] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
return GuildSticker(state=self._state, data=data)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the custom :class:`Sticker` from the guild.
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
do this.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this sticker. Shows up on the audit log.
Raises
-------
Forbidden
You are not allowed to delete stickers.
HTTPException
An error occurred deleting the sticker.
"""
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason)
def _sticker_factory(
sticker_type: Literal[1, 2]
) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type)
if value == StickerType.standard:
return StandardSticker, value
elif value == StickerType.guild:
return GuildSticker, value
else:
return Sticker, value
return Asset._from_sticker(self._state, self.id, self._image)

View File

@ -40,8 +40,8 @@ if TYPE_CHECKING:
)
__all__ = (
"Team",
"TeamMember",
'Team',
'TeamMember',
)
@ -62,26 +62,26 @@ class Team:
.. versionadded:: 1.3
"""
__slots__ = ("_state", "id", "name", "_icon", "owner_id", "members")
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members')
def __init__(self, state: ConnectionState, data: TeamPayload):
self._state: ConnectionState = state
self.id: int = int(data["id"])
self.name: str = data["name"]
self._icon: Optional[str] = data["icon"]
self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id")
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]]
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data['icon']
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id')
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']]
def __repr__(self) -> str:
return f"<{self.__class__.__name__} id={self.id} name={self.name}>"
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path="team")
return Asset._from_icon(self._state, self.id, self._icon, path='team')
@property
def owner(self) -> Optional[TeamMember]:
@ -130,16 +130,16 @@ class TeamMember(BaseUser):
The membership state of the member (e.g. invited or accepted)
"""
__slots__ = ("team", "membership_state", "permissions")
__slots__ = BaseUser.__slots__ + ('team', 'membership_state', 'permissions')
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload):
self.team: Team = team
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data["membership_state"])
self.permissions: List[str] = data["permissions"]
super().__init__(state=state, data=data["user"])
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
self.permissions: List[str] = data['permissions']
super().__init__(state=state, data=data['user'])
def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} id={self.id} name={self.name!r} "
f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>"
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
)

View File

@ -24,25 +24,24 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Optional, TYPE_CHECKING
from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from typing import Any, Optional, TYPE_CHECKING, overload
from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data
from .enums import VoiceRegion
from .guild import Guild
__all__ = ("Template",)
__all__ = (
'Template',
)
if TYPE_CHECKING:
import datetime
from .types.template import Template as TemplatePayload
from .state import ConnectionState
from .user import User
class _FriendlyHttpAttributeErrorHelper:
__slots__ = ()
def __getattr__(self, attr):
raise AttributeError("PartialTemplateState does not support http methods.")
raise AttributeError('PartialTemplateState does not support http methods.')
class _PartialTemplateState:
@ -75,14 +74,11 @@ class _PartialTemplateState:
def _get_message(self, id):
return None
def _get_guild(self, id):
return self.__state._get_guild(id)
async def query_members(self, **kwargs: Any):
async def query_members(self, **kwargs):
return []
def __getattr__(self, attr):
raise AttributeError(f"PartialTemplateState does not support {attr!r}.")
raise AttributeError(f'PartialTemplateState does not support {attr!r}.')
class Template:
@ -116,55 +112,53 @@ class Template:
"""
__slots__ = (
"code",
"uses",
"name",
"description",
"creator",
"created_at",
"updated_at",
"source_guild",
"is_dirty",
"_state",
'code',
'uses',
'name',
'description',
'creator',
'created_at',
'updated_at',
'source_guild',
'is_dirty',
'_state',
)
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
def __init__(self, *, state, data: TemplatePayload):
self._state = state
self._store(data)
def _store(self, data: TemplatePayload) -> None:
self.code: str = data["code"]
self.uses: int = data["usage_count"]
self.name: str = data["name"]
self.description: Optional[str] = data["description"]
creator_data = data.get("creator")
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data)
def _store(self, data: TemplatePayload):
self.code = data['code']
self.uses = data['usage_count']
self.name = data['name']
self.description = data['description']
creator_data = data.get('creator')
self.creator = None if creator_data is None else self._state.store_user(creator_data)
self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at"))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get("updated_at"))
self.created_at = parse_time(data.get('created_at'))
self.updated_at = parse_time(data.get('updated_at'))
guild_id = int(data["source_guild_id"])
guild: Optional[Guild] = self._state._get_guild(guild_id)
id = _get_as_snowflake(data, 'source_guild_id')
self.source_guild: Guild
if guild is None:
source_serialised = data["serialized_source_guild"]
source_serialised["id"] = guild_id
guild = self._state._get_guild(id)
if guild is None and id:
source_serialised = data['serialized_source_guild']
source_serialised['id'] = id
state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
guild = Guild(data=source_serialised, state=state)
self.is_dirty: Optional[bool] = data.get("is_dirty", None)
self.source_guild = guild
self.is_dirty = data.get('is_dirty', None)
def __repr__(self) -> str:
return (
f"<Template code={self.code!r} uses={self.uses} name={self.name!r}"
f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>"
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}'
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild:
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None):
"""|coro|
Creates a :class:`.Guild` using the template.
@ -204,7 +198,7 @@ class Template:
data = await self._state.http.create_from_template(self.code, name, region_value, icon)
return Guild(data=data, state=self._state)
async def sync(self) -> Template:
async def sync(self) -> None:
"""|coro|
Sync the template to the guild's current state.
@ -214,9 +208,6 @@ class Template:
.. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Raises
-------
HTTPException
@ -225,22 +216,25 @@ class Template:
You don't have permissions to edit the template.
NotFound
This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
"""
data = await self._state.http.sync_template(self.source_guild.id, self.code)
return Template(state=self._state, data=data)
self._store(data)
@overload
async def edit(
self,
*,
name: str = MISSING,
description: Optional[str] = MISSING,
) -> Template:
name: Optional[str] = ...,
description: Optional[str] = ...,
) -> None:
...
@overload
async def edit(self) -> None:
...
async def edit(self, **kwargs) -> None:
"""|coro|
Edit the template metadata.
@ -250,15 +244,12 @@ class Template:
.. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Parameters
------------
name: :class:`str`
name: Optional[:class:`str`]
The template's new name.
description: Optional[:class:`str`]
The template's new description.
The template's description.
Raises
-------
@ -268,21 +259,9 @@ class Template:
You don't have permissions to edit the template.
NotFound
This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
"""
payload = {}
if name is not MISSING:
payload["name"] = name
if description is not MISSING:
payload["description"] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
return Template(state=self._state, data=data)
data = await self._state.http.edit_template(self.source_guild.id, self.code, kwargs)
self._store(data)
async def delete(self) -> None:
"""|coro|
@ -308,7 +287,7 @@ class Template:
@property
def url(self) -> str:
""":class:`str`: The template url.
.. versionadded:: 2.0
"""
return f"https://discord.new/{self.code}"
return f'https://discord.new/{self.code}'

View File

@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
import time
import asyncio
@ -35,8 +34,8 @@ from .errors import ClientException
from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = (
"Thread",
"ThreadMember",
'Thread',
'ThreadMember',
)
if TYPE_CHECKING:
@ -46,11 +45,10 @@ if TYPE_CHECKING:
ThreadMetadata,
ThreadArchiveDuration,
)
from .types.snowflake import SnowflakeList
from .guild import Guild
from .channel import TextChannel, CategoryChannel
from .channel import TextChannel
from .member import Member
from .message import Message, PartialMessage
from .message import Message
from .abc import Snowflake, SnowflakeTime
from .role import Role
from .permissions import Permissions
@ -74,10 +72,6 @@ class Thread(Messageable, Hashable):
Returns the thread's hash.
.. describe:: int(x)
Returns the thread's ID.
.. describe:: str(x)
Returns the thread's name.
@ -115,9 +109,6 @@ class Thread(Messageable, Hashable):
Whether the thread is archived.
locked: :class:`bool`
Whether the thread is locked.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
This is always ``True`` for public threads.
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
auto_archive_duration: :class:`int`
@ -128,29 +119,28 @@ class Thread(Messageable, Hashable):
"""
__slots__ = (
"name",
"id",
"guild",
"_type",
"_state",
"_members",
"owner_id",
"parent_id",
"last_message_id",
"message_count",
"member_count",
"slowmode_delay",
"me",
"locked",
"archived",
"invitable",
"archiver_id",
"auto_archive_duration",
"archive_timestamp",
'name',
'id',
'guild',
'_type',
'_state',
'_members',
'owner_id',
'parent_id',
'last_message_id',
'message_count',
'member_count',
'slowmode_delay',
'me',
'locked',
'archived',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
self._state: ConnectionState = state
def __init__(self, *, guild: Guild, data: ThreadPayload):
self._state: ConnectionState = guild._state
self.guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)
@ -160,83 +150,58 @@ class Thread(Messageable, Hashable):
def __repr__(self) -> str:
return (
f"<Thread id={self.id!r} name={self.name!r} parent={self.parent}"
f" owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>"
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}'
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>'
)
def __str__(self) -> str:
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data["id"])
self.parent_id = int(data["parent_id"])
self.owner_id = int(data["owner_id"])
self.name = data["name"]
self._type = try_enum(ChannelType, data["type"])
self.last_message_id = _get_as_snowflake(data, "last_message_id")
self.slowmode_delay = data.get("rate_limit_per_user", 0)
self.message_count = data["message_count"]
self.member_count = data["member_count"]
self._unroll_metadata(data["thread_metadata"])
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self._unroll_metadata(data['thread_metadata'])
try:
member = data["member"]
member = data['member']
except KeyError:
self.me = None
else:
self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data["archived"]
self.archiver_id = _get_as_snowflake(data, "archiver_id")
self.auto_archive_duration = data["auto_archive_duration"]
self.archive_timestamp = parse_time(data["archive_timestamp"])
self.locked = data.get("locked", False)
self.invitable = data.get("invitable", True)
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
def _update(self, data):
try:
self.name = data["name"]
self.name = data['name']
except KeyError:
pass
self.slowmode_delay = data.get("rate_limit_per_user", 0)
try:
self._unroll_metadata(data["thread_metadata"])
self._unroll_metadata(data['thread_metadata'])
except KeyError:
pass
@property
def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return self._type
@property
def parent(self) -> Optional[TextChannel]:
"""Optional[:class:`TextChannel`]: The parent channel this thread belongs to."""
return self.guild.get_channel(self.parent_id) # type: ignore
return self.guild.get_channel(self.parent_id)
@property
def owner(self) -> Optional[Member]:
"""Optional[:class:`Member`]: The member this thread belongs to."""
return self.guild.get_member(self.owner_id)
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread."""
return f"<#{self.id}>"
@property
def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
This requires :attr:`Intents.members` to be properly filled. Most of the time however,
this data is not provided by the gateway and a call to :meth:`fetch_members` is
needed.
"""
return list(self._members.values())
@property
def last_message(self) -> Optional[Message]:
"""Fetches the last message from this channel in cache.
@ -258,26 +223,6 @@ class Thread(Messageable, Hashable):
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None
@property
def category(self) -> Optional[CategoryChannel]:
"""The category channel the parent channel belongs to, if applicable.
Raises
-------
ClientException
The parent channel was not cached and returned ``None``.
Returns
-------
Optional[:class:`CategoryChannel`]
The parent channel's category.
"""
parent = self.parent
if parent is None:
raise ClientException("Parent channel not found")
return parent.category
@property
def category_id(self) -> Optional[int]:
"""The category channel ID the parent channel belongs to, if applicable.
@ -295,7 +240,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException("Parent channel not found")
raise ClientException('Parent channel not found')
return parent.category_id
def is_private(self) -> bool:
@ -314,15 +259,6 @@ class Thread(Messageable, Hashable):
"""
return self._type is ChannelType.news_thread
def is_nsfw(self) -> bool:
""":class:`bool`: Whether the thread is NSFW or not.
An NSFW thread is a thread that has a parent that is an NSFW channel,
i.e. :meth:`.TextChannel.is_nsfw` is ``True``.
"""
parent = self.parent
return parent is not None and parent.is_nsfw()
def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
"""Handles permission resolution for the :class:`~discord.Member`
or :class:`~discord.Role`.
@ -352,7 +288,7 @@ class Thread(Messageable, Hashable):
parent = self.parent
if parent is None:
raise ClientException("Parent channel not found")
raise ClientException('Parent channel not found')
return parent.permissions_for(obj)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
@ -402,15 +338,15 @@ class Thread(Messageable, Hashable):
return
if len(messages) > 100:
raise ClientException("Can only bulk delete messages up to 100 messages")
raise ClientException('Can only bulk delete messages up to 100 messages')
message_ids: SnowflakeList = [m.id for m in messages]
message_ids = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
async def purge(
self,
*,
limit: Optional[int] = 100,
limit: int = 100,
check: Callable[[Message], bool] = MISSING,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
@ -530,10 +466,9 @@ class Thread(Messageable, Hashable):
name: str = MISSING,
archived: bool = MISSING,
locked: bool = MISSING,
invitable: bool = MISSING,
slowmode_delay: int = MISSING,
auto_archive_duration: ThreadArchiveDuration = MISSING,
) -> Thread:
):
"""|coro|
Edits the thread.
@ -553,11 +488,8 @@ class Thread(Messageable, Hashable):
Whether to archive the thread or not.
locked: :class:`bool`
Whether to lock the thread or not.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
Only available for private threads.
auto_archive_duration: :class:`int`
The new duration in minutes before a thread is automatically archived for inactivity.
The new duration to auto archive threads for inactivity.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
slowmode_delay: :class:`int`
Specifies the slowmode rate limit for user in this thread, in seconds.
@ -569,37 +501,30 @@ class Thread(Messageable, Hashable):
You do not have permissions to edit the thread.
HTTPException
Editing the thread failed.
Returns
--------
:class:`Thread`
The newly edited thread.
"""
payload = {}
if name is not MISSING:
payload["name"] = str(name)
payload['name'] = str(name)
if archived is not MISSING:
payload["archived"] = archived
payload['archived'] = archived
if auto_archive_duration is not MISSING:
payload["auto_archive_duration"] = auto_archive_duration
payload['auto_archive_duration'] = auto_archive_duration
if locked is not MISSING:
payload["locked"] = locked
if invitable is not MISSING:
payload["invitable"] = invitable
payload['locked'] = locked
if slowmode_delay is not MISSING:
payload["rate_limit_per_user"] = slowmode_delay
payload['rate_limit_per_user'] = slowmode_delay
data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
await self._state.http.edit_channel(self.id, **payload)
async def join(self):
"""|coro|
Joins this thread.
You must have :attr:`~Permissions.send_messages_in_threads` to join a thread.
If the thread is private, :attr:`~Permissions.manage_threads` is also needed.
You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads`
to join a public thread. If the thread is private then :attr:`~Permissions.send_messages`
and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages`
is required to join the thread.
Raises
-------
@ -667,28 +592,6 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.remove_user_from_thread(self.id, user.id)
async def fetch_members(self) -> List[ThreadMember]:
"""|coro|
Retrieves all :class:`ThreadMember` that are in this thread.
This requires :attr:`Intents.members` to get information about members
other than yourself.
Raises
-------
HTTPException
Retrieving the members failed.
Returns
--------
List[:class:`ThreadMember`]
All thread members in the thread.
"""
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self):
"""|coro|
@ -705,29 +608,6 @@ class Thread(Messageable, Hashable):
"""
await self._state.http.delete_channel(self.id)
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
doing an unnecessary API call.
.. versionadded:: 2.0
Parameters
------------
message_id: :class:`int`
The message ID to create a partial message for.
Returns
---------
:class:`PartialMessage`
The partial message.
"""
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
@ -752,10 +632,6 @@ class ThreadMember(Hashable):
Returns the thread member's hash.
.. describe:: int(x)
Returns the thread member's ID.
.. describe:: str(x)
Returns the thread member's name.
@ -773,12 +649,12 @@ class ThreadMember(Hashable):
"""
__slots__ = (
"id",
"thread_id",
"joined_at",
"flags",
"_state",
"parent",
'id',
'thread_id',
'joined_at',
'flags',
'_state',
'parent',
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
@ -787,60 +663,24 @@ class ThreadMember(Hashable):
self._from_data(data)
def __repr__(self) -> str:
return f"<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>"
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload):
try:
self.id = int(data["user_id"])
self.id = int(data['user_id'])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
try:
self.thread_id = int(data["id"])
self.thread_id = int(data['id'])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data["join_timestamp"])
self.flags = data["flags"]
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
@property
def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to."""
return self.parent
async def fetch_member(self) -> Member:
"""|coro|
Retrieves a :class:`Member` from the ThreadMember object.
.. note::
This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead.
Raises
-------
Forbidden
You do not have access to the guild.
HTTPException
Fetching the member failed.
Returns
--------
:class:`Member`
The member.
"""
return await self.thread.guild.fetch_member(self.id)
def get_member(self) -> Optional[Member]:
"""
Get the :class:`Member` from cache for the ThreadMember object.
Returns
--------
Optional[:class:`Member`]
The member or ``None`` if not found.
"""
return self.thread.guild.get_member(self.id)

View File

@ -29,7 +29,7 @@ from .user import PartialUser
from .snowflake import Snowflake
StatusType = Literal["idle", "dnd", "online", "offline"]
StatusType = Literal['idle', 'dnd', 'online', 'offline']
class PartialPresenceUpdate(TypedDict):
@ -41,9 +41,9 @@ class PartialPresenceUpdate(TypedDict):
class ClientStatus(TypedDict, total=False):
desktop: str
mobile: str
web: str
desktop: bool
mobile: bool
web: bool
class ActivityTimestamps(TypedDict, total=False):

View File

@ -30,7 +30,6 @@ from .user import User
from .team import Team
from .snowflake import Snowflake
class BaseAppInfo(TypedDict):
id: Snowflake
name: str
@ -39,7 +38,6 @@ class BaseAppInfo(TypedDict):
summary: str
description: str
class _AppInfoOptional(TypedDict, total=False):
team: Team
guild_id: Snowflake
@ -50,14 +48,12 @@ class _AppInfoOptional(TypedDict, total=False):
hook: bool
max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str]
owner: User
bot_public: bool
bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str]
cover_image: str
@ -65,8 +61,6 @@ class _PartialAppInfoOptional(TypedDict, total=False):
terms_of_service_url: str
privacy_policy_url: str
max_participants: int
flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass

View File

@ -32,7 +32,6 @@ from .user import User
from .snowflake import Snowflake
from .role import Role
from .channel import ChannelType, VideoQualityMode, PermissionOverwrite
from .threads import Thread
AuditLogEvent = Literal[
1,
@ -70,54 +69,35 @@ AuditLogEvent = Literal[
80,
81,
82,
83,
84,
85,
90,
91,
92,
110,
111,
112,
]
class _AuditLogChange_Str(TypedDict):
key: Literal[
"name",
"description",
"preferred_locale",
"vanity_url_code",
"topic",
"code",
"allow",
"deny",
"permissions",
"tags",
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions'
]
new_value: str
old_value: str
class _AuditLogChange_AssetHash(TypedDict):
key: Literal["icon_hash", "splash_hash", "discovery_splash_hash", "banner_hash", "avatar_hash", "asset"]
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash']
new_value: str
old_value: str
class _AuditLogChange_Snowflake(TypedDict):
key: Literal[
"id",
"owner_id",
"afk_channel_id",
"rules_channel_id",
"public_updates_channel_id",
"widget_channel_id",
"system_channel_id",
"application_id",
"channel_id",
"inviter_id",
"guild_id",
'id',
'owner_id',
'afk_channel_id',
'rules_channel_id',
'public_updates_channel_id',
'widget_channel_id',
'system_channel_id',
'application_id',
'channel_id',
'inviter_id',
]
new_value: Snowflake
old_value: Snowflake
@ -125,20 +105,17 @@ class _AuditLogChange_Snowflake(TypedDict):
class _AuditLogChange_Bool(TypedDict):
key: Literal[
"widget_enabled",
"nsfw",
"hoist",
"mentionable",
"temporary",
"deaf",
"mute",
"nick",
"enabled_emoticons",
"region",
"rtc_region",
"available",
"archived",
"locked",
'widget_enabled',
'nsfw',
'hoist',
'mentionable',
'temporary',
'deaf',
'mute',
'nick',
'enabled_emoticons',
'region',
'rtc_region',
]
new_value: bool
old_value: bool
@ -146,72 +123,70 @@ class _AuditLogChange_Bool(TypedDict):
class _AuditLogChange_Int(TypedDict):
key: Literal[
"afk_timeout",
"prune_delete_days",
"position",
"bitrate",
"rate_limit_per_user",
"color",
"max_uses",
"max_age",
"user_limit",
"auto_archive_duration",
"default_auto_archive_duration",
'afk_timeout',
'prune_delete_days',
'position',
'bitrate',
'rate_limit_per_user',
'color',
'max_uses',
'max_age',
'user_limit',
]
new_value: int
old_value: int
class _AuditLogChange_ListRole(TypedDict):
key: Literal["$add", "$remove"]
key: Literal['$add', '$remove']
new_value: List[Role]
old_value: List[Role]
class _AuditLogChange_MFALevel(TypedDict):
key: Literal["mfa_level"]
key: Literal['mfa_level']
new_value: MFALevel
old_value: MFALevel
class _AuditLogChange_VerificationLevel(TypedDict):
key: Literal["verification_level"]
key: Literal['verification_level']
new_value: VerificationLevel
old_value: VerificationLevel
class _AuditLogChange_ExplicitContentFilter(TypedDict):
key: Literal["explicit_content_filter"]
key: Literal['explicit_content_filter']
new_value: ExplicitContentFilterLevel
old_value: ExplicitContentFilterLevel
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
key: Literal["default_message_notifications"]
key: Literal['default_message_notifications']
new_value: DefaultMessageNotificationLevel
old_value: DefaultMessageNotificationLevel
class _AuditLogChange_ChannelType(TypedDict):
key: Literal["type"]
key: Literal['type']
new_value: ChannelType
old_value: ChannelType
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
key: Literal["expire_behavior"]
key: Literal['expire_behavior']
new_value: IntegrationExpireBehavior
old_value: IntegrationExpireBehavior
class _AuditLogChange_VideoQualityMode(TypedDict):
key: Literal["video_quality_mode"]
key: Literal['video_quality_mode']
new_value: VideoQualityMode
old_value: VideoQualityMode
class _AuditLogChange_Overwrites(TypedDict):
key: Literal["permission_overwrites"]
key: Literal['permission_overwrites']
new_value: List[PermissionOverwrite]
old_value: List[PermissionOverwrite]
@ -241,7 +216,7 @@ class AuditEntryInfo(TypedDict):
message_id: Snowflake
count: str
id: Snowflake
type: Literal["0", "1"]
type: Literal['0', '1']
role_name: str
@ -263,4 +238,3 @@ class AuditLog(TypedDict):
users: List[User]
audit_log_entries: List[AuditLogEntry]
integrations: List[PartialIntegration]
threads: List[Thread]

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, Optional, TypedDict, Union
from .user import PartialUser
from .snowflake import Snowflake
from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration
from .threads import ThreadMetadata, ThreadMember
OverwriteType = Literal[0, 1]
@ -63,7 +63,6 @@ class _TextChannelOptional(TypedDict, total=False):
last_message_id: Optional[Snowflake]
last_pin_timestamp: str
rate_limit_per_user: int
default_auto_archive_duration: ThreadArchiveDuration
class TextChannel(_BaseGuildChannel, _TextChannelOptional):
@ -79,13 +78,13 @@ VideoQualityMode = Literal[1, 2]
class _VoiceChannelOptional(TypedDict, total=False):
rtc_region: Optional[str]
bitrate: int
user_limit: int
video_quality_mode: VideoQualityMode
class VoiceChannel(_BaseGuildChannel, _VoiceChannelOptional):
type: Literal[2]
bitrate: int
user_limit: int
class CategoryChannel(_BaseGuildChannel):
@ -98,13 +97,13 @@ class StoreChannel(_BaseGuildChannel):
class _StageChannelOptional(TypedDict, total=False):
rtc_region: Optional[str]
bitrate: int
user_limit: int
topic: str
class StageChannel(_BaseGuildChannel, _StageChannelOptional):
type: Literal[13]
bitrate: int
user_limit: int
class _ThreadChannelOptional(TypedDict, total=False):
@ -116,7 +115,7 @@ class _ThreadChannelOptional(TypedDict, total=False):
class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
type: Literal[10, 11, 12]
type: Literal[11, 12]
guild_id: Snowflake
parent_id: Snowflake
owner_id: Snowflake

View File

@ -53,7 +53,6 @@ class _SelectMenuOptional(TypedDict, total=False):
placeholder: str
min_values: int
max_values: int
disabled: bool
class _SelectOptionsOptional(TypedDict, total=False):

View File

@ -24,60 +24,49 @@ DEALINGS IN THE SOFTWARE.
from typing import List, Literal, TypedDict
class _EmbedFooterOptional(TypedDict, total=False):
icon_url: str
proxy_icon_url: str
class EmbedFooter(_EmbedFooterOptional):
text: str
class _EmbedFieldOptional(TypedDict, total=False):
inline: bool
class EmbedField(_EmbedFieldOptional):
name: str
value: str
class EmbedThumbnail(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedVideo(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedImage(TypedDict, total=False):
url: str
proxy_url: str
height: int
width: int
class EmbedProvider(TypedDict, total=False):
name: str
url: str
class EmbedAuthor(TypedDict, total=False):
name: str
url: str
icon_url: str
proxy_icon_url: str
EmbedType = Literal["rich", "image", "video", "gifv", "article", "link"]
EmbedType = Literal['rich', 'image', 'video', 'gifv', 'article', 'link']
class Embed(TypedDict, total=False):
title: str

View File

@ -75,28 +75,21 @@ VerificationLevel = Literal[0, 1, 2, 3, 4]
NSFWLevel = Literal[0, 1, 2, 3]
PremiumTier = Literal[0, 1, 2, 3]
GuildFeature = Literal[
"ANIMATED_ICON",
"BANNER",
"COMMERCE",
"COMMUNITY",
"DISCOVERABLE",
"FEATURABLE",
"INVITE_SPLASH",
"MEMBER_VERIFICATION_GATE_ENABLED",
"MONETIZATION_ENABLED",
"MORE_EMOJI",
"MORE_STICKERS",
"NEWS",
"PARTNERED",
"PREVIEW_ENABLED",
"PRIVATE_THREADS",
"SEVEN_DAY_THREAD_ARCHIVE",
"THREE_DAY_THREAD_ARCHIVE",
"TICKETED_EVENTS_ENABLED",
"VANITY_URL",
"VERIFIED",
"VIP_REGIONS",
"WELCOME_SCREEN_ENABLED",
'INVITE_SPLASH',
'VIP_REGIONS',
'VANITY_URL',
'VERIFIED',
'PARTNERED',
'COMMUNITY',
'COMMERCE',
'NEWS',
'DISCOVERABLE',
'FEATURABLE',
'ANIMATED_ICON',
'BANNER',
'WELCOME_SCREEN_ENABLED',
'MEMBER_VERIFICATION_GATE_ENABLED',
'PREVIEW_ENABLED',
]
@ -159,10 +152,8 @@ class ChannelPositionUpdate(TypedDict):
lock_permissions: Optional[bool]
parent_id: Optional[Snowflake]
class _RolePositionRequired(TypedDict):
id: Snowflake
class RolePositionUpdate(_RolePositionRequired, total=False):
position: Optional[Snowflake]

View File

@ -56,7 +56,7 @@ class PartialIntegration(TypedDict):
account: IntegrationAccount
IntegrationType = Literal["twitch", "youtube", "discord"]
IntegrationType = Literal['twitch', 'youtube', 'discord']
class BaseIntegration(PartialIntegration):
@ -69,7 +69,7 @@ class BaseIntegration(PartialIntegration):
class StreamIntegration(BaseIntegration):
role_id: Optional[Snowflake]
role_id: Snowflake
enable_emoticons: bool
subscriber_count: int
revoked: bool

View File

@ -37,12 +37,8 @@ if TYPE_CHECKING:
from .message import AllowedMentions, Message
ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
type: ApplicationCommandType
class ApplicationCommand(_ApplicationCommandOptional):
@ -57,7 +53,7 @@ class _ApplicationCommandOptionOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9]
class ApplicationCommandOption(_ApplicationCommandOptionOptional):
@ -97,48 +93,16 @@ class GuildApplicationCommandPermissions(PartialGuildApplicationCommandPermissio
InteractionType = Literal[1, 2, 3]
class _ApplicationCommandInteractionDataOption(TypedDict):
name: str
class _ApplicationCommandInteractionDataOptionSubcommand(_ApplicationCommandInteractionDataOption):
type: Literal[1, 2]
class _ApplicationCommandInteractionDataOptionOptional(TypedDict, total=False):
value: ApplicationCommandOptionType
options: List[ApplicationCommandInteractionDataOption]
class _ApplicationCommandInteractionDataOptionString(_ApplicationCommandInteractionDataOption):
type: Literal[3]
value: str
class _ApplicationCommandInteractionDataOptionInteger(_ApplicationCommandInteractionDataOption):
type: Literal[4]
value: int
class _ApplicationCommandInteractionDataOptionBoolean(_ApplicationCommandInteractionDataOption):
type: Literal[5]
value: bool
class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInteractionDataOption):
type: Literal[6, 7, 8, 9]
value: Snowflake
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
type: Literal[10]
value: float
ApplicationCommandInteractionDataOption = Union[
_ApplicationCommandInteractionDataOptionString,
_ApplicationCommandInteractionDataOptionInteger,
_ApplicationCommandInteractionDataOptionSubcommand,
_ApplicationCommandInteractionDataOptionBoolean,
_ApplicationCommandInteractionDataOptionSnowflake,
_ApplicationCommandInteractionDataOptionNumber,
]
class ApplicationCommandInteractionDataOption(
_ApplicationCommandInteractionDataOptionOptional
):
name: str
type: ApplicationCommandOptionType
class ApplicationCommandResolvedPartialChannel(TypedDict):
@ -158,8 +122,6 @@ class ApplicationCommandInteractionDataResolved(TypedDict, total=False):
class _ApplicationCommandInteractionDataOptional(TypedDict, total=False):
options: List[ApplicationCommandInteractionDataOption]
resolved: ApplicationCommandInteractionDataResolved
target_id: Snowflake
type: ApplicationCommandType
class ApplicationCommandInteractionData(_ApplicationCommandInteractionDataOptional):
@ -176,11 +138,8 @@ class ComponentInteractionData(_ComponentInteractionDataOptional):
component_type: ComponentType
InteractionData = Union[ApplicationCommandInteractionData, ComponentInteractionData]
class _InteractionOptional(TypedDict, total=False):
data: InteractionData
data: Union[ApplicationCommandInteractionData, ComponentInteractionData]
guild_id: Snowflake
channel_id: Snowflake
member: Member
@ -223,12 +182,8 @@ class MessageInteraction(TypedDict):
user: User
class _EditApplicationCommandOptional(TypedDict, total=False):
class EditApplicationCommand(TypedDict):
name: str
description: str
options: Optional[List[ApplicationCommandOption]]
type: ApplicationCommandType
default_permission: bool
class EditApplicationCommand(_EditApplicationCommandOptional):
name: str

View File

@ -39,25 +39,8 @@ class PartialMember(TypedDict):
class Member(PartialMember, total=False):
avatar: str
user: User
nick: str
premium_since: str
pending: bool
permissions: str
class _OptionalMemberWithUser(PartialMember, total=False):
avatar: str
nick: str
premium_since: str
pending: bool
permissions: str
class MemberWithUser(_OptionalMemberWithUser):
user: User
class UserWithMember(User, total=False):
member: _OptionalMemberWithUser

View File

@ -26,14 +26,13 @@ from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
from .snowflake import Snowflake, SnowflakeList
from .member import Member, UserWithMember
from .member import Member
from .user import User
from .emoji import PartialEmoji
from .embed import Embed
from .channel import ChannelType
from .components import Component
from .interactions import MessageInteraction
from .sticker import StickerItem
class ChannelMention(TypedDict):
@ -53,7 +52,6 @@ class _AttachmentOptional(TypedDict, total=False):
height: Optional[int]
width: Optional[int]
content_type: str
ephemeral: bool
spoiler: bool
@ -91,6 +89,22 @@ class MessageReference(TypedDict, total=False):
fail_if_not_exists: bool
class _StickerOptional(TypedDict, total=False):
tags: str
StickerFormatType = Literal[1, 2, 3]
class Sticker(_StickerOptional):
id: Snowflake
pack_id: Snowflake
name: str
description: str
asset: str
format_type: StickerFormatType
class _MessageOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
@ -103,7 +117,7 @@ class _MessageOptional(TypedDict, total=False):
application_id: Snowflake
message_reference: MessageReference
flags: int
sticker_items: List[StickerItem]
stickers: List[Sticker]
referenced_message: Optional[Message]
interaction: MessageInteraction
components: List[Component]
@ -121,7 +135,7 @@ class Message(_MessageOptional):
edited_timestamp: Optional[str]
tts: bool
mention_everyone: bool
mentions: List[UserWithMember]
mentions: List[User]
mention_roles: SnowflakeList
attachments: List[Attachment]
embeds: List[Embed]
@ -129,7 +143,7 @@ class Message(_MessageOptional):
type: MessageType
AllowedMentionType = Literal["roles", "users", "everyone"]
AllowedMentionType = Literal['roles', 'users', 'everyone']
class AllowedMentions(TypedDict):

View File

@ -1,98 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict, List
from .snowflake import Snowflake
from .member import Member
from .emoji import PartialEmoji
class _MessageEventOptional(TypedDict, total=False):
guild_id: Snowflake
class MessageDeleteEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class BulkMessageDeleteEvent(_MessageEventOptional):
ids: List[Snowflake]
channel_id: Snowflake
class _ReactionActionEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class MessageUpdateEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class ReactionActionEvent(_ReactionActionEventOptional):
user_id: Snowflake
channel_id: Snowflake
message_id: Snowflake
emoji: PartialEmoji
class _ReactionClearEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEvent(_ReactionClearEventOptional):
channel_id: Snowflake
message_id: Snowflake
class _ReactionClearEmojiEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEmojiEvent(_ReactionClearEmojiEventOptional):
channel_id: int
message_id: int
emoji: PartialEmoji
class _IntegrationDeleteEventOptional(TypedDict, total=False):
application_id: Snowflake
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake
guild_id: Snowflake
class _TypingEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class TypingEvent(_TypingEventOptional):
channel_id: Snowflake
user_id: Snowflake
timestamp: int

View File

@ -1,93 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, TypedDict, Union
from .snowflake import Snowflake
from .user import User
StickerFormatType = Literal[1, 2, 3]
class StickerItem(TypedDict):
id: Snowflake
name: str
format_type: StickerFormatType
class BaseSticker(TypedDict):
id: Snowflake
name: str
description: str
tags: str
format_type: StickerFormatType
class StandardSticker(BaseSticker):
type: Literal[1]
sort_value: int
pack_id: Snowflake
class _GuildStickerOptional(TypedDict, total=False):
user: User
class GuildSticker(BaseSticker, _GuildStickerOptional):
type: Literal[2]
available: bool
guild_id: Snowflake
Sticker = Union[BaseSticker, StandardSticker, GuildSticker]
class StickerPack(TypedDict):
id: Snowflake
stickers: List[StandardSticker]
name: str
sku_id: Snowflake
cover_sticker_id: Snowflake
description: str
banner_asset_id: Snowflake
class _CreateGuildStickerOptional(TypedDict, total=False):
description: str
class CreateGuildSticker(_CreateGuildStickerOptional):
name: str
tags: str
class EditGuildSticker(TypedDict, total=False):
name: str
tags: str
description: str
class ListPremiumStickerPacks(TypedDict):
sticker_packs: List[StickerPack]

View File

@ -29,14 +29,12 @@ from typing import TypedDict, List, Optional
from .user import PartialUser
from .snowflake import Snowflake
class TeamMember(TypedDict):
user: PartialUser
membership_state: int
permissions: List[str]
team_id: Snowflake
class Team(TypedDict):
id: Snowflake
name: str

View File

@ -41,7 +41,6 @@ class ThreadMember(TypedDict):
class _ThreadMetadataOptional(TypedDict, total=False):
archiver_id: Snowflake
locked: bool
invitable: bool
class ThreadMetadata(_ThreadMetadataOptional):

View File

@ -22,16 +22,13 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Optional, TypedDict, List, Literal
from typing import Optional, TypedDict
from .snowflake import Snowflake
from .member import MemberWithUser
SupportedModes = Literal["xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"]
from .member import Member
class _PartialVoiceStateOptional(TypedDict, total=False):
member: MemberWithUser
member: Member
self_stream: bool
@ -62,24 +59,3 @@ class VoiceRegion(TypedDict):
optimal: bool
deprecated: bool
custom: bool
class VoiceServerUpdate(TypedDict):
token: str
guild_id: Snowflake
endpoint: Optional[str]
class VoiceIdentify(TypedDict):
server_id: Snowflake
user_id: Snowflake
session_id: str
token: str
class VoiceReady(TypedDict):
ssrc: int
ip: str
port: int
modes: List[SupportedModes]
heartbeat_interval: int

View File

@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Any, Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from typing import Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
import inspect
import os
@ -35,16 +35,16 @@ from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent
__all__ = (
"Button",
"button",
'Button',
'button',
)
if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
B = TypeVar("B", bound="Button")
V = TypeVar("V", bound="View", covariant=True)
B = TypeVar('B', bound='Button')
V = TypeVar('V', bound='View', covariant=True)
class Button(Item[V]):
@ -60,12 +60,12 @@ class Button(Item[V]):
The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID.
url: Optional[:class:`str`]
The URL this button sends you to. This param is automatically casted to :class:`str`.
The URL this button sends you to.
disabled: :class:`bool`
Whether the button is disabled or not.
label: Optional[:class:`str`]
The label of the button, if any.
emoji: Optional[Union[:class:`.PartialEmoji`, :class:`.Emoji`, :class:`str`]]
emoji: Optional[Union[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`]]
The emoji of the button, if available.
row: Optional[:class:`int`]
The relative row this button belongs to. A Discord component can only have 5
@ -76,28 +76,28 @@ class Button(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
"style",
"url",
"disabled",
"label",
"emoji",
"row",
'style',
'url',
'disabled',
'label',
'emoji',
'row',
)
def __init__(
self,
*,
style: ButtonStyle = ButtonStyle.secondary,
style: ButtonStyle,
label: Optional[str] = None,
disabled: bool = False,
custom_id: Optional[str] = None,
url: Optional[Any] = None,
url: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None,
):
super().__init__()
if custom_id is not None and url is not None:
raise TypeError("cannot mix both url and custom_id with Button")
raise TypeError('cannot mix both url and custom_id with Button')
self._provided_custom_id = custom_id is not None
if url is None and custom_id is None:
@ -112,12 +112,12 @@ class Button(Item[V]):
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}")
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
custom_id=custom_id,
url=str(url) if url else None,
url=url,
disabled=disabled,
label=label,
style=style,
@ -145,7 +145,7 @@ class Button(Item[V]):
@custom_id.setter
def custom_id(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError("custom_id must be None or str")
raise TypeError('custom_id must be None or str')
self._underlying.custom_id = value
@ -157,7 +157,7 @@ class Button(Item[V]):
@url.setter
def url(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError("url must be None or str")
raise TypeError('url must be None or str')
self._underlying.url = value
@property
@ -180,7 +180,7 @@ class Button(Item[V]):
@property
def emoji(self) -> Optional[PartialEmoji]:
"""Optional[:class:`.PartialEmoji`]: The emoji of the button, if available."""
"""Optional[:class:`PartialEmoji`]: The emoji of the button, if available."""
return self._underlying.emoji
@emoji.setter
@ -191,7 +191,7 @@ class Button(Item[V]):
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead")
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
self._underlying.emoji = None
@ -217,11 +217,6 @@ class Button(Item[V]):
def is_dispatchable(self) -> bool:
return self.custom_id is not None
def is_persistent(self) -> bool:
if self.style is ButtonStyle.link:
return self.url is not None
return super().is_persistent()
def refresh_component(self, button: ButtonComponent) -> None:
self._underlying = button
@ -256,13 +251,13 @@ def button(
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
It is recommended not to set this parameter to prevent conflicts.
style: :class:`.ButtonStyle`
The style of the button. Defaults to :attr:`.ButtonStyle.grey`.
style: :class:`ButtonStyle`
The style of the button. Defaults to :attr:`ButtonStyle.grey`.
disabled: :class:`bool`
Whether the button is disabled or not. Defaults to ``False``.
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
The emoji of the button. This can be in string form or a :class:`.PartialEmoji`
or a full :class:`.Emoji`.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
The emoji of the button. This can be in string form or a :class:`PartialEmoji`
or a full :class:`Emoji`.
row: Optional[:class:`int`]
The relative row this button belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -272,18 +267,20 @@ def button(
"""
def decorator(func: ItemCallbackType) -> ItemCallbackType:
nonlocal custom_id
if not inspect.iscoroutinefunction(func):
raise TypeError("button function must be a coroutine function")
raise TypeError('button function must be a coroutine function')
custom_id = custom_id or os.urandom(32).hex()
func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = {
"style": style,
"custom_id": custom_id,
"url": None,
"disabled": disabled,
"label": label,
"emoji": emoji,
"row": row,
'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'row': row,
}
return func

View File

@ -28,15 +28,17 @@ from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECK
from ..interactions import Interaction
__all__ = ("Item",)
__all__ = (
'Item',
)
if TYPE_CHECKING:
from ..enums import ComponentType
from .view import View
from ..components import Component
I = TypeVar("I", bound="Item")
V = TypeVar("V", bound="View", covariant=True)
I = TypeVar('I', bound='Item')
V = TypeVar('V', bound='View', covariant=True)
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
@ -46,12 +48,9 @@ class Item(Generic[V]):
The current UI items supported are:
- :class:`discord.ui.Button`
- :class:`discord.ui.Select`
.. versionadded:: 2.0
"""
__item_repr_attributes__: Tuple[str, ...] = ("row",)
__item_repr_attributes__: Tuple[str, ...] = ('row',)
def __init__(self):
self._view: Optional[V] = None
@ -89,8 +88,8 @@ class Item(Generic[V]):
return self._provided_custom_id
def __repr__(self) -> str:
attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__)
return f"<{self.__class__.__name__} {attrs}>"
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
return f'<{self.__class__.__name__} {attrs}>'
@property
def row(self) -> Optional[int]:
@ -103,10 +102,11 @@ class Item(Generic[V]):
elif 5 > value >= 0:
self._row = value
else:
raise ValueError("row cannot be negative or greater than or equal to 5")
raise ValueError('row cannot be negative or greater than or equal to 5')
@property
def width(self) -> int:
""":class:`int`: The width of the item."""
return 1
@property
@ -123,7 +123,7 @@ class Item(Generic[V]):
Parameters
-----------
interaction: :class:`.Interaction`
interaction: :class:`Interaction`
The interaction that triggered this UI item.
"""
pass

View File

@ -39,8 +39,8 @@ from ..components import (
)
__all__ = (
"Select",
"select",
'Select',
'select',
)
if TYPE_CHECKING:
@ -50,8 +50,8 @@ if TYPE_CHECKING:
ComponentInteractionData,
)
S = TypeVar("S", bound="Select")
V = TypeVar("V", bound="View", covariant=True)
S = TypeVar('S', bound='Select')
V = TypeVar('V', bound='View', covariant=True)
class Select(Item[V]):
@ -59,8 +59,6 @@ class Select(Item[V]):
This is usually represented as a drop down menu.
In order to get the selected items that the user has chosen, use :attr:`Select.values`.
.. versionadded:: 2.0
Parameters
@ -72,14 +70,12 @@ class Select(Item[V]):
The placeholder text that is shown if nothing is selected, if any.
min_values: :class:`int`
The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 0 and 25.
Defaults to 1 and must be between 1 and 25.
max_values: :class:`int`
The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
@ -89,11 +85,10 @@ class Select(Item[V]):
"""
__item_repr_attributes__: Tuple[str, ...] = (
"placeholder",
"min_values",
"max_values",
"options",
"disabled",
'placeholder',
'min_values',
'max_values',
'options',
)
def __init__(
@ -104,10 +99,8 @@ class Select(Item[V]):
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> None:
super().__init__()
self._selected_values: List[str] = []
self._provided_custom_id = custom_id is not MISSING
custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id
@ -119,7 +112,6 @@ class Select(Item[V]):
min_values=min_values,
max_values=max_values,
options=options,
disabled=disabled,
)
self.row = row
@ -131,7 +123,7 @@ class Select(Item[V]):
@custom_id.setter
def custom_id(self, value: str):
if not isinstance(value, str):
raise TypeError("custom_id must be None or str")
raise TypeError('custom_id must be None or str')
self._underlying.custom_id = value
@ -143,7 +135,7 @@ class Select(Item[V]):
@placeholder.setter
def placeholder(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError("placeholder must be None or str")
raise TypeError('placeholder must be None or str')
self._underlying.placeholder = value
@ -173,9 +165,9 @@ class Select(Item[V]):
@options.setter
def options(self, value: List[SelectOption]):
if not isinstance(value, list):
raise TypeError("options must be a list of SelectOption")
raise TypeError('options must be a list of SelectOption')
if not all(isinstance(obj, SelectOption) for obj in value):
raise TypeError("all list items must subclass SelectOption")
raise TypeError('all list items must subclass SelectOption')
self._underlying.options = value
@ -197,16 +189,16 @@ class Select(Item[V]):
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
Can only be up to 25 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not given, defaults to the label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
Can only be up to 50 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
The emoji of the option, if available. This can either be a string representing
the custom or unicode emoji or an instance of :class:`.PartialEmoji` or :class:`.Emoji`.
the custom or unicode emoji or an instance of :class:`PartialEmoji` or :class:`Emoji`.
default: :class:`bool`
Whether this option is selected by default.
@ -224,6 +216,7 @@ class Select(Item[V]):
default=default,
)
self.append_option(option)
def append_option(self, option: SelectOption):
@ -241,19 +234,10 @@ class Select(Item[V]):
"""
if len(self._underlying.options) > 25:
raise ValueError("maximum number of options already provided")
raise ValueError('maximum number of options already provided')
self._underlying.options.append(option)
@property
def disabled(self) -> bool:
""":class:`bool`: Whether the select is disabled or not."""
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
self._underlying.disabled = bool(value)
@property
def values(self) -> List[str]:
"""List[:class:`str`]: A list of values that have been selected by the user."""
@ -271,7 +255,7 @@ class Select(Item[V]):
def refresh_state(self, interaction: Interaction) -> None:
data: ComponentInteractionData = interaction.data # type: ignore
self._selected_values = data.get("values", [])
self._selected_values = data.get('values', [])
@classmethod
def from_component(cls: Type[S], component: SelectMenu) -> S:
@ -281,7 +265,6 @@ class Select(Item[V]):
min_values=component.min_values,
max_values=component.max_values,
options=component.options,
disabled=component.disabled,
row=None,
)
@ -300,7 +283,6 @@ def select(
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a select menu to a component.
@ -309,8 +291,6 @@ def select(
the :class:`discord.ui.View`, the :class:`discord.ui.Select` being pressed and
the :class:`discord.Interaction` you receive.
In order to get the selected items that the user has chosen within the callback
use :attr:`Select.values`.
Parameters
------------
@ -327,29 +307,26 @@ def select(
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
min_values: :class:`int`
The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 0 and 25.
Defaults to 1 and must be between 1 and 25.
max_values: :class:`int`
The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not. Defaults to ``False``.
"""
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError("select function must be a coroutine function")
raise TypeError('button function must be a coroutine function')
func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = {
"placeholder": placeholder,
"custom_id": custom_id,
"row": row,
"min_values": min_values,
"max_values": max_values,
"options": options,
"disabled": disabled,
'placeholder': placeholder,
'custom_id': custom_id,
'row': row,
'min_values': min_values,
'max_values': max_values,
'options': options,
}
return func

View File

@ -33,20 +33,21 @@ import sys
import time
import os
from .item import Item, ItemCallbackType
from ..enums import ComponentType
from ..components import (
Component,
ActionRow as ActionRowComponent,
_component_factory,
Button as ButtonComponent,
SelectMenu as SelectComponent,
)
__all__ = ("View",)
__all__ = (
'View',
)
if TYPE_CHECKING:
from ..interactions import Interaction
from ..message import Message
from ..types.components import Component as ComponentPayload
from ..state import ConnectionState
@ -64,15 +65,13 @@ def _component_to_item(component: Component) -> Item:
from .button import Button
return Button.from_component(component)
if isinstance(component, SelectComponent):
from .select import Select
return Select.from_component(component)
return Item.from_component(component)
class _ViewWeights:
__slots__ = ("weights",)
__slots__ = (
'weights',
)
def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0]
@ -88,13 +87,13 @@ class _ViewWeights:
if weight + item.width <= 5:
return index
raise ValueError("could not find open space for item")
raise ValueError('could not find open space for item')
def add_item(self, item: Item) -> None:
if item.row is not None:
total = self.weights[item.row] + item.width
if total > 5:
raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)")
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
self.weights[item.row] = total
item._rendered_row = item.row
else:
@ -110,18 +109,15 @@ class _ViewWeights:
def clear(self) -> None:
self.weights = [0, 0, 0, 0, 0]
class View:
"""Represents a UI view.
This object must be inherited to create a UI within Discord.
.. versionadded:: 2.0
Parameters
-----------
timeout: Optional[:class:`float`]
Timeout in seconds from last interaction with the UI before no longer accepting input.
Timeout from last interaction with the UI before no longer accepting input.
If ``None`` then there is no timeout.
Attributes
@ -140,15 +136,15 @@ class View:
children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
if hasattr(member, "__discord_ui_model_type__"):
if hasattr(member, '__discord_ui_model_type__'):
children.append(member)
if len(children) > 25:
raise TypeError("View cannot have more than 25 children")
raise TypeError('View cannot have more than 25 children')
cls.__view_children_items__ = children
def __init__(self, *, timeout: Optional[float] = 180.0):
def __init__(self, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.children: List[Item] = []
for func in self.__view_children_items__:
@ -160,31 +156,13 @@ class View:
self.__weights = _ViewWeights(self.children)
loop = asyncio.get_running_loop()
self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None
self.__timeout_expiry: Optional[float] = None
self.__timeout_task: Optional[asyncio.Task[None]] = None
self.__stopped: asyncio.Future[bool] = loop.create_future()
self.id = os.urandom(16).hex()
self._cancel_callback: Optional[Callable[[View], None]] = None
self._timeout_handler: Optional[asyncio.TimerHandle] = None
self._stopped = loop.create_future()
def __repr__(self) -> str:
return f"<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>"
async def __timeout_task_impl(self) -> None:
while True:
# Guard just in case someone changes the value of the timeout at runtime
if self.timeout is None:
return
if self.__timeout_expiry is None:
return self._dispatch_timeout()
# Check if we've elapsed our currently set timeout
now = time.monotonic()
if now >= self.__timeout_expiry:
return self._dispatch_timeout()
# Wait N seconds to see if timeout data has been refreshed
await asyncio.sleep(self.__timeout_expiry - now)
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
def to_components(self) -> List[Dict[str, Any]]:
def key(item: Item) -> int:
@ -199,40 +177,13 @@ class View:
components.append(
{
"type": 1,
"components": children,
'type': 1,
'components': children,
}
)
return components
@classmethod
def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View:
"""Converts a message's components into a :class:`View`.
The :attr:`.Message.components` of a message are read-only
and separate types from those in the ``discord.ui`` namespace.
In order to modify and edit message components they must be
converted into a :class:`View` first.
Parameters
-----------
message: :class:`discord.Message`
The message with components to convert into a view.
timeout: Optional[:class:`float`]
The timeout of the converted view.
Returns
--------
:class:`View`
The converted view. This always returns a :class:`View` and not
one of its subclasses.
"""
view = View(timeout=timeout)
for component in _walk_all_components(message.components):
view.add_item(_component_to_item(component))
return view
@property
def _expires_at(self) -> Optional[float]:
if self.timeout:
@ -257,10 +208,10 @@ class View:
"""
if len(self.children) > 25:
raise ValueError("maximum number of children exceeded")
raise ValueError('maximum number of children exceeded')
if not isinstance(item, Item):
raise TypeError(f"expected Item not {item.__class__!r}")
raise TypeError(f'expected Item not {item.__class__!r}')
self.__weights.add_item(item)
@ -301,8 +252,9 @@ class View:
.. note::
If an exception occurs within the body then the check
is considered a failure and :meth:`on_error` is called.
If an exception occurs within the body then the interaction
check then :meth:`on_error` is called and it is considered
a failure.
Parameters
-----------
@ -340,46 +292,36 @@ class View:
interaction: :class:`~discord.Interaction`
The interaction that led to the failure.
"""
print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr)
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
async def _scheduled_task(self, item: Item, interaction: Interaction):
async def _scheduled_task(self, state: Any, item: Item, interaction: Interaction):
try:
if self.timeout:
self.__timeout_expiry = time.monotonic() + self.timeout
allow = await self.interaction_check(interaction)
if not allow:
return
await item.callback(interaction)
if not interaction.response.is_done():
if not interaction.response._responded:
await interaction.response.defer()
except Exception as e:
return await self.on_error(e, item, interaction)
def _start_listening_from_store(self, store: ViewStore) -> None:
self.__cancel_callback = partial(store.remove_view)
def _start_listening(self, store: ViewStore) -> None:
self._cancel_callback = partial(store.remove_view)
if self.timeout:
loop = asyncio.get_running_loop()
if self.__timeout_task is not None:
self.__timeout_task.cancel()
self._timeout_handler = loop.call_later(self.timeout, self.dispatch_timeout)
self.__timeout_expiry = time.monotonic() + self.timeout
self.__timeout_task = loop.create_task(self.__timeout_task_impl())
def _dispatch_timeout(self):
if self.__stopped.done():
def dispatch_timeout(self):
if self._stopped.done():
return
self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}")
self._stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
def _dispatch_item(self, item: Item, interaction: Interaction):
if self.__stopped.done():
return
asyncio.create_task(self._scheduled_task(item, interaction), name=f"discord-ui-view-dispatch-{self.id}")
def dispatch(self, state: Any, item: Item, interaction: Interaction):
asyncio.create_task(self._scheduled_task(state, item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]):
# This is pretty hacky at the moment
@ -407,25 +349,23 @@ class View:
This operation cannot be undone.
"""
if not self.__stopped.done():
self.__stopped.set_result(False)
if not self._stopped.done():
self._stopped.set_result(False)
self.__timeout_expiry = None
if self.__timeout_task is not None:
self.__timeout_task.cancel()
self.__timeout_task = None
if self._timeout_handler:
self._timeout_handler.cancel()
if self.__cancel_callback:
self.__cancel_callback(self)
self.__cancel_callback = None
if self._cancel_callback:
self._cancel_callback(self)
self._cancel_callback = None
def is_finished(self) -> bool:
""":class:`bool`: Whether the view has finished interacting."""
return self.__stopped.done()
return self._stopped.done()
def is_dispatching(self) -> bool:
""":class:`bool`: Whether the view has been added for dispatching purposes."""
return self.__cancel_callback is not None
return self._cancel_callback is not None
def is_persistent(self) -> bool:
""":class:`bool`: Whether the view is set up as persistent.
@ -447,32 +387,31 @@ class View:
If ``True``, then the view timed out. If ``False`` then
the view finished normally.
"""
return await self.__stopped
return await self._stopped
class ViewStore:
def __init__(self, state: ConnectionState):
# (component_type, message_id, custom_id): (View, Item)
self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {}
# (component_type, custom_id): (View, Item, Expiry)
self._views: Dict[Tuple[int, str], Tuple[View, Item, Optional[float]]] = {}
# message_id: View
self._synced_message_views: Dict[int, View] = {}
self._state: ConnectionState = state
@property
def persistent_views(self) -> Sequence[View]:
# fmt: off
views = {
view.id: view
for (_, (view, _)) in self._views.items()
for (_, (view, _, _)) in self._views.items()
if view.is_persistent()
}
# fmt: on
return list(views.values())
def __verify_integrity(self):
to_remove: List[Tuple[int, Optional[int], str]] = []
for (k, (view, _)) in self._views.items():
if view.is_finished():
to_remove: List[Tuple[int, str]] = []
now = time.monotonic()
for (k, (_, _, expiry)) in self._views.items():
if expiry is not None and now >= expiry:
to_remove.append(k)
for k in to_remove:
@ -481,10 +420,11 @@ class ViewStore:
def add_view(self, view: View, message_id: Optional[int] = None):
self.__verify_integrity()
view._start_listening_from_store(self)
expiry = view._expires_at
view._start_listening(self)
for item in view.children:
if item.is_dispatchable():
self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore
self._views[(item.type.value, item.custom_id)] = (view, item, expiry) # type: ignore
if message_id is not None:
self._synced_message_views[message_id] = view
@ -501,17 +441,15 @@ class ViewStore:
def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
self.__verify_integrity()
message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
# Fallback to None message_id searches in case a persistent view
# was added without an associated message_id
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
key = (component_type, custom_id)
value = self._views.get(key)
if value is None:
return
view, item = value
view, item, _ = value
self._views[key] = (view, item, view._expires_at)
item.refresh_state(interaction)
view._dispatch_item(item, interaction)
view.dispatch(self._state, item, interaction)
def is_message_tracked(self, message_id: int):
return message_id in self._synced_message_views

View File

@ -22,54 +22,24 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
import discord.abc
from .asset import Asset
from .colour import Colour
from .enums import DefaultAvatar
from .flags import PublicUserFlags
from .utils import snowflake_time, _bytes_to_base64_data, MISSING
if TYPE_CHECKING:
from datetime import datetime
from .channel import DMChannel
from .guild import Guild
from .message import Message
from .state import ConnectionState
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload
from .utils import snowflake_time, _bytes_to_base64_data
from .enums import DefaultAvatar
from .colour import Colour
from .asset import Asset
__all__ = (
"User",
"ClientUser",
'User',
'ClientUser',
)
BU = TypeVar("BU", bound="BaseUser")
_BaseUser = discord.abc.User
class _UserTag:
__slots__ = ()
id: int
class BaseUser(_UserTag):
__slots__ = (
"name",
"id",
"discriminator",
"_avatar",
"_banner",
"_accent_colour",
"bot",
"system",
"_public_flags",
"_state",
)
class BaseUser(_BaseUser):
__slots__ = ('name', 'id', 'discriminator', '_avatar', 'bot', 'system', '_public_flags', '_state')
if TYPE_CHECKING:
name: str
@ -77,150 +47,84 @@ class BaseUser(_UserTag):
discriminator: str
bot: bool
system: bool
_state: ConnectionState
_avatar: Optional[str]
_banner: Optional[str]
_accent_colour: Optional[int]
_public_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
def __init__(self, *, state, data):
self._state = state
self._update(data)
def __repr__(self) -> str:
def __repr__(self):
return (
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} system={self.system}>"
)
def __str__(self) -> str:
return f"{self.name}#{self.discriminator}"
def __str__(self):
return f'{self.name}#{self.discriminator}'
def __int__(self) -> int:
return self.id
def __eq__(self, other):
return isinstance(other, _BaseUser) and other.id == self.id
def __eq__(self, other: Any) -> bool:
return isinstance(other, _UserTag) and other.id == self.id
def __ne__(self, other: Any) -> bool:
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self) -> int:
def __hash__(self):
return self.id >> 22
def _update(self, data: UserPayload) -> None:
self.name = data["username"]
self.id = int(data["id"])
self.discriminator = data["discriminator"]
self._avatar = data["avatar"]
self._banner = data.get("banner", None)
self._accent_colour = data.get("accent_color", None)
self._public_flags = data.get("public_flags", 0)
self.bot = data.get("bot", False)
self.system = data.get("system", False)
def _update(self, data):
self.name = data['username']
self.id = int(data['id'])
self.discriminator = data['discriminator']
self._avatar = data['avatar']
self._public_flags = data.get('public_flags', 0)
self.bot = data.get('bot', False)
self.system = data.get('system', False)
@classmethod
def _copy(cls: Type[BU], user: BU) -> BU:
def _copy(cls, user):
self = cls.__new__(cls) # bypass __init__
self.name = user.name
self.id = user.id
self.discriminator = user.discriminator
self._avatar = user._avatar
self._banner = user._banner
self._accent_colour = user._accent_colour
self.bot = user.bot
self._state = user._state
self._public_flags = user._public_flags
return self
def _to_minimal_user_json(self) -> Dict[str, Any]:
def _to_minimal_user_json(self):
return {
"username": self.name,
"id": self.id,
"avatar": self._avatar,
"discriminator": self.discriminator,
"bot": self.bot,
'username': self.name,
'id': self.id,
'avatar': self._avatar,
'discriminator': self.discriminator,
'bot': self.bot,
}
@property
def public_flags(self) -> PublicUserFlags:
def public_flags(self):
""":class:`PublicUserFlags`: The publicly available flags the user has."""
return PublicUserFlags._from_value(self._public_flags)
@property
def avatar(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the avatar the user has.
def avatar(self):
""":class:`Asset`: Returns an :class:`Asset` for the avatar the user has.
If the user does not have a traditional avatar, ``None`` is returned.
If you want the avatar that a user has displayed, consider :attr:`display_avatar`.
If the user does not have a traditional avatar, an asset for
the default avatar is returned instead.
"""
if self._avatar is not None:
if self._avatar is None:
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
else:
return Asset._from_avatar(self._state, self.id, self._avatar)
return None
@property
def default_avatar(self) -> Asset:
def default_avatar(self):
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
@property
def display_avatar(self) -> Asset:
""":class:`Asset`: Returns the user's display avatar.
For regular users this is just their default avatar or uploaded avatar.
.. versionadded:: 2.0
"""
return self.avatar or self.default_avatar
@property
def banner(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns the user's banner asset, if available.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._banner is None:
return None
return Asset._from_user_banner(self._state, self.id, self._banner)
@property
def accent_colour(self) -> Optional[Colour]:
"""Optional[:class:`Colour`]: Returns the user's accent colour, if applicable.
There is an alias for this named :attr:`accent_color`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._accent_colour is None:
return None
return Colour(self._accent_colour)
@property
def accent_color(self) -> Optional[Colour]:
"""Optional[:class:`Colour`]: Returns the user's accent color, if applicable.
There is an alias for this named :attr:`accent_colour`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
return self.accent_colour
@property
def colour(self) -> Colour:
def colour(self):
""":class:`Colour`: A property that returns a colour denoting the rendered colour
for the user. This always returns :meth:`Colour.default`.
@ -229,7 +133,7 @@ class BaseUser(_UserTag):
return Colour.default()
@property
def color(self) -> Colour:
def color(self):
""":class:`Colour`: A property that returns a color denoting the rendered color
for the user. This always returns :meth:`Colour.default`.
@ -238,12 +142,12 @@ class BaseUser(_UserTag):
return self.colour
@property
def mention(self) -> str:
def mention(self):
""":class:`str`: Returns a string that allows you to mention the given user."""
return f"<@{self.id}>"
return f'<@{self.id}>'
@property
def created_at(self) -> datetime:
def created_at(self):
""":class:`datetime.datetime`: Returns the user's creation time in UTC.
This is when the user's Discord account was created.
@ -251,7 +155,7 @@ class BaseUser(_UserTag):
return snowflake_time(self.id)
@property
def display_name(self) -> str:
def display_name(self):
""":class:`str`: Returns the user's display name.
For regular users this is just their username, but
@ -260,7 +164,7 @@ class BaseUser(_UserTag):
"""
return self.name
def mentioned_in(self, message: Message) -> bool:
def mentioned_in(self, message):
"""Checks if the user is mentioned in the specified message.
Parameters
@ -324,32 +228,26 @@ class ClientUser(BaseUser):
Specifies if the user has MFA turned on and working.
"""
__slots__ = ("locale", "_flags", "verified", "mfa_enabled", "__weakref__")
__slots__ = BaseUser.__slots__ + ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
if TYPE_CHECKING:
verified: bool
locale: Optional[str]
mfa_enabled: bool
_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
def __init__(self, *, state, data):
super().__init__(state=state, data=data)
def __repr__(self) -> str:
def __repr__(self):
return (
f"<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>"
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}'
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
)
def _update(self, data: UserPayload) -> None:
def _update(self, data):
super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get("verified", False)
self.locale = data.get("locale")
self._flags = data.get("flags", 0)
self.mfa_enabled = data.get("mfa_enabled", False)
self.verified = data.get('verified', False)
self.locale = data.get('locale')
self._flags = data.get('flags', 0)
self.mfa_enabled = data.get('mfa_enabled', False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
async def edit(self, *, username: str = None, avatar: Optional[bytes] = None) -> None:
"""|coro|
Edits the current profile of the client.
@ -363,9 +261,6 @@ class ClientUser(BaseUser):
The only image formats supported for uploading is JPEG and PNG.
.. versionchanged:: 2.0
The edit is no longer in-place, instead the newly edited client user is returned.
Parameters
-----------
username: :class:`str`
@ -380,21 +275,13 @@ class ClientUser(BaseUser):
Editing your profile failed.
InvalidArgument
Wrong image format passed for ``avatar``.
Returns
---------
:class:`ClientUser`
The newly edited client user.
"""
payload: Dict[str, Any] = {}
if username is not MISSING:
payload["username"] = username
if avatar is not MISSING:
payload["avatar"] = _bytes_to_base64_data(avatar)
if avatar is not None:
avatar = _bytes_to_base64_data(avatar)
data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data)
data = await self._state.http.edit_profile(username=username, avatar=avatar)
self._update(data)
class User(BaseUser, discord.abc.Messageable):
@ -418,10 +305,6 @@ class User(BaseUser, discord.abc.Messageable):
Returns the user's name with discriminator.
.. describe:: int(x)
Returns the user's ID.
Attributes
-----------
name: :class:`str`
@ -436,34 +319,17 @@ class User(BaseUser, discord.abc.Messageable):
Specifies if the user is a system user (i.e. represents Discord officially).
"""
__slots__ = ("_stored",)
__slots__ = BaseUser.__slots__ + ('__weakref__',)
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data)
self._stored: bool = False
def __repr__(self):
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
def __repr__(self) -> str:
return f"<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>"
def __del__(self) -> None:
try:
if self._stored:
self._state.deref_user(self.id)
except Exception:
pass
@classmethod
def _copy(cls, user: User):
self = super()._copy(user)
self._stored = False
return self
async def _get_channel(self) -> DMChannel:
async def _get_channel(self):
ch = await self.create_dm()
return ch
@property
def dm_channel(self) -> Optional[DMChannel]:
def dm_channel(self):
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
If this returns ``None``, you can create a DM channel by calling the
@ -472,7 +338,7 @@ class User(BaseUser, discord.abc.Messageable):
return self._state._get_private_channel_by_user(self.id)
@property
def mutual_guilds(self) -> List[Guild]:
def mutual_guilds(self):
"""List[:class:`Guild`]: The guilds that the user shares with the client.
.. note::
@ -483,7 +349,7 @@ class User(BaseUser, discord.abc.Messageable):
"""
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
async def create_dm(self) -> DMChannel:
async def create_dm(self):
"""|coro|
Creates a :class:`DMChannel` with this user.
@ -501,5 +367,5 @@ class User(BaseUser, discord.abc.Messageable):
return found
state = self._state
data: DMChannelPayload = await state.http.start_private_message(self.id)
data = await state.http.start_private_message(self.id)
return state.add_dm_channel(data)

View File

@ -63,27 +63,18 @@ import warnings
from .errors import InvalidArgument
try:
import orjson
except ModuleNotFoundError:
HAS_ORJSON = False
else:
HAS_ORJSON = True
__all__ = (
"oauth_url",
"snowflake_time",
"time_snowflake",
"find",
"get",
"sleep_until",
"utcnow",
"remove_markdown",
"escape_markdown",
"escape_mentions",
"as_chunks",
"format_dt",
'oauth_url',
'snowflake_time',
'time_snowflake',
'find',
'get',
'sleep_until',
'utcnow',
'remove_markdown',
'escape_markdown',
'escape_mentions',
'as_chunks',
)
DISCORD_EPOCH = 1420070400000
@ -97,7 +88,7 @@ class _MissingSentinel:
return False
def __repr__(self):
return "..."
return '...'
MISSING: Any = _MissingSentinel()
@ -106,7 +97,7 @@ MISSING: Any = _MissingSentinel()
class _cached_property:
def __init__(self, function):
self.function = function
self.__doc__ = getattr(function, "__doc__")
self.__doc__ = getattr(function, '__doc__')
def __get__(self, instance, owner):
if instance is None:
@ -120,9 +111,6 @@ class _cached_property:
if TYPE_CHECKING:
from functools import cached_property as cached_property
from typing_extensions import ParamSpec
from .permissions import Permissions
from .abc import Snowflake
from .invite import Invite
@ -131,14 +119,13 @@ if TYPE_CHECKING:
class _RequestLike(Protocol):
headers: Mapping[str, Any]
P = ParamSpec("P")
else:
cached_property = _cached_property
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
_Iter = Union[Iterator[T], AsyncIterator[T]]
@ -146,7 +133,7 @@ class CachedSlotProperty(Generic[T, T_co]):
def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.name = name
self.function = function
self.__doc__ = getattr(function, "__doc__")
self.__doc__ = getattr(function, '__doc__')
@overload
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
@ -176,7 +163,7 @@ class classproperty(Generic[T_co]):
return self.fget(owner)
def __set__(self, instance, value) -> None:
raise AttributeError("cannot set attribute")
raise AttributeError('cannot set attribute')
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
@ -235,8 +222,8 @@ def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
return None
def copy_doc(original: Callable) -> Callable[[T], T]:
def decorator(overriden: T) -> T:
def copy_doc(original: Callable[..., Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(overriden: Callable[..., Any]) -> Callable[..., Any]:
overriden.__doc__ = original.__doc__
overriden.__signature__ = _signature(original) # type: ignore
return overriden
@ -244,18 +231,18 @@ def copy_doc(original: Callable) -> Callable[[T], T]:
return decorator
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]:
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[..., T]], Callable[..., T]]:
def actual_decorator(func: Callable[..., T]) -> Callable[..., T]:
@functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T:
warnings.simplefilter("always", DeprecationWarning) # turn off filter
def decorated(*args, **kwargs) -> T:
warnings.simplefilter('always', DeprecationWarning) # turn off filter
if instead:
fmt = "{0.__name__} is deprecated, use {1} instead."
else:
fmt = "{0.__name__} is deprecated."
fmt = '{0.__name__} is deprecated.'
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
warnings.simplefilter("default", DeprecationWarning) # reset filter
warnings.simplefilter('default', DeprecationWarning) # reset filter
return func(*args, **kwargs)
return decorated
@ -264,20 +251,18 @@ def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Call
def oauth_url(
client_id: Union[int, str],
*,
permissions: Permissions = MISSING,
guild: Snowflake = MISSING,
redirect_uri: str = MISSING,
scopes: Iterable[str] = MISSING,
disable_guild_select: bool = False,
) -> str:
client_id: str,
permissions: Optional[Permissions] = None,
guild: Optional[Snowflake] = None,
redirect_uri: Optional[str] = None,
scopes: Optional[Iterable[str]] = None,
):
"""A helper function that returns the OAuth2 URL for inviting the bot
into guilds.
Parameters
-----------
client_id: Union[:class:`int`, :class:`str`]
client_id: :class:`str`
The client ID for your bot.
permissions: :class:`~discord.Permissions`
The permissions you're requesting. If not given then you won't be requesting any
@ -290,28 +275,22 @@ def oauth_url(
An optional valid list of scopes. Defaults to ``('bot',)``.
.. versionadded:: 1.7
disable_guild_select: :class:`bool`
Whether to disallow the user from changing the guild dropdown.
.. versionadded:: 2.0
Returns
--------
:class:`str`
The OAuth2 URL for inviting the bot into guilds.
"""
url = f"https://discord.com/oauth2/authorize?client_id={client_id}"
url += "&scope=" + "+".join(scopes or ("bot",))
if permissions is not MISSING:
url += f"&permissions={permissions.value}"
if guild is not MISSING:
url += f"&guild_id={guild.id}"
if redirect_uri is not MISSING:
url = f'https://discord.com/oauth2/authorize?client_id={client_id}'
url = url + '&scope=' + '+'.join(scopes or ('bot',))
if permissions is not None:
url = url + '&permissions=' + str(permissions.value)
if guild is not None:
url = url + "&guild_id=" + str(guild.id)
if redirect_uri is not None:
from urllib.parse import urlencode
url += "&response_type=code&" + urlencode({"redirect_uri": redirect_uri})
if disable_guild_select:
url += "&disable_guild_select=true"
url = url + "&response_type=code&" + urlencode({'redirect_uri': redirect_uri})
return url
@ -328,7 +307,7 @@ def snowflake_time(id: int) -> datetime.datetime:
An aware datetime in UTC representing the creation time of the snowflake.
"""
timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
@ -434,13 +413,13 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
# Special case the single element call
if len(attrs) == 1:
k, v = attrs.popitem()
pred = attrget(k.replace("__", "."))
pred = attrget(k.replace('__', '.'))
for elem in iterable:
if pred(elem) == v:
return elem
return None
converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()]
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()]
for elem in iterable:
if _all(pred(elem) == value for pred, value in converted):
@ -462,46 +441,35 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
def _get_mime_type_for_image(data: bytes):
if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"):
return "image/png"
elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"):
return "image/jpeg"
elif data.startswith((b"\x47\x49\x46\x38\x37\x61", b"\x47\x49\x46\x38\x39\x61")):
return "image/gif"
elif data.startswith(b"RIFF") and data[8:12] == b"WEBP":
return "image/webp"
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
return 'image/png'
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
return 'image/jpeg'
elif data.startswith((b'\x47\x49\x46\x38\x37\x61', b'\x47\x49\x46\x38\x39\x61')):
return 'image/gif'
elif data.startswith(b'RIFF') and data[8:12] == b'WEBP':
return 'image/webp'
else:
raise InvalidArgument("Unsupported image type given")
raise InvalidArgument('Unsupported image type given')
def _bytes_to_base64_data(data: bytes) -> str:
fmt = "data:{mime};base64,{data}"
fmt = 'data:{mime};base64,{data}'
mime = _get_mime_type_for_image(data)
b64 = b64encode(data).decode("ascii")
b64 = b64encode(data).decode('ascii')
return fmt.format(mime=mime, data=b64)
if HAS_ORJSON:
def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode("utf-8")
_from_json = orjson.loads # type: ignore
else:
def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(",", ":"), ensure_ascii=True)
_from_json = json.loads
def to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get("X-Ratelimit-Reset-After")
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if use_clock or not reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after)
@ -611,7 +579,7 @@ class SnowflakeList(array.array):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) # type: ignore
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
i = bisect_left(self, element)
@ -626,7 +594,7 @@ class SnowflakeList(array.array):
return i != len(self) and self[i] == element
_IS_ASCII = re.compile(r"^[\x00-\x7f]+$")
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
@ -635,7 +603,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
if match:
return match.endpos
UNICODE_WIDE_CHAR_TYPE = "WFA"
UNICODE_WIDE_CHAR_TYPE = 'WFA'
func = unicodedata.east_asian_width
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
@ -659,7 +627,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite):
return invite.code
else:
rx = r"(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)"
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
m = re.match(rx, invite)
if m:
return m.group(1)
@ -687,24 +655,22 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template):
return code.code
else:
rx = r"(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)"
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
m = re.match(rx, code)
if m:
return m.group(1)
return code
_MARKDOWN_ESCAPE_SUBREGEX = "|".join(r"\{0}(?=([\s\S]*((?<!\{0})\{0})))".format(c) for c in ("*", "`", "_", "~", "|"))
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|'))
_MARKDOWN_ESCAPE_COMMON = r"^>(?:>>)?\s|\[.+\]\(.+\)"
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
_MARKDOWN_ESCAPE_REGEX = re.compile(
fr"(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})", re.MULTILINE
)
_MARKDOWN_ESCAPE_REGEX = re.compile(fr'(?P<markdown>{_MARKDOWN_ESCAPE_SUBREGEX}|{_MARKDOWN_ESCAPE_COMMON})', re.MULTILINE)
_URL_REGEX = r"(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])"
_URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\'\]\s])'
_MARKDOWN_STOCK_REGEX = fr"(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})"
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
@ -733,11 +699,11 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
def replacement(match):
groupdict = match.groupdict()
return groupdict.get("url", "")
return groupdict.get('url', '')
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f"(?:{_URL_REGEX}|{regex})"
regex = f'(?:{_URL_REGEX}|{regex})'
return re.sub(regex, replacement, text, 0, re.MULTILINE)
@ -770,18 +736,18 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
def replacement(match):
groupdict = match.groupdict()
is_url = groupdict.get("url")
is_url = groupdict.get('url')
if is_url:
return is_url
return "\\" + groupdict["markdown"]
return '\\' + groupdict['markdown']
regex = _MARKDOWN_STOCK_REGEX
if ignore_links:
regex = f"(?:{_URL_REGEX}|{regex})"
regex = f'(?:{_URL_REGEX}|{regex})'
return re.sub(regex, replacement, text, 0, re.MULTILINE)
else:
text = re.sub(r"\\", r"\\\\", text)
return _MARKDOWN_ESCAPE_REGEX.sub(r"\\\1", text)
text = re.sub(r'\\', r'\\\\', text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
def escape_mentions(text: str) -> str:
@ -807,7 +773,7 @@ def escape_mentions(text: str) -> str:
:class:`str`
The text with the mentions removed.
"""
return re.sub(r"@(everyone|here|[!&]?[0-9]{17,20})", "@\u200b\\1", text)
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text)
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
@ -871,7 +837,7 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
A new iterator which yields chunks of a given size.
"""
if max_size <= 0:
raise ValueError("Chunk sizes must be greater than 0.")
raise ValueError('Chunk sizes must be greater than 0.')
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
@ -917,12 +883,12 @@ def evaluate_annotation(
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, "__args__"):
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, "__origin__"):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.Union:
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
@ -939,12 +905,10 @@ def evaluate_annotation(
implicit_str = False
is_literal = True
evaluated_args = tuple(
evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
)
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.")
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.')
if evaluated_args == args:
return tp
@ -972,51 +936,3 @@ def resolve_annotation(
if cache is None:
cache = {}
return evaluate_annotation(annotation, globalns, locals, cache)
TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"]
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
"""A helper function to format a :class:`datetime.datetime` for presentation within Discord.
This allows for a locale-independent way of presenting data using Discord specific Markdown.
+-------------+----------------------------+-----------------+
| Style | Example Output | Description |
+=============+============================+=================+
| t | 22:57 | Short Time |
+-------------+----------------------------+-----------------+
| T | 22:57:58 | Long Time |
+-------------+----------------------------+-----------------+
| d | 17/05/2016 | Short Date |
+-------------+----------------------------+-----------------+
| D | 17 May 2016 | Long Date |
+-------------+----------------------------+-----------------+
| f (default) | 17 May 2016 22:57 | Short Date Time |
+-------------+----------------------------+-----------------+
| F | Tuesday, 17 May 2016 22:57 | Long Date Time |
+-------------+----------------------------+-----------------+
| R | 5 years ago | Relative Time |
+-------------+----------------------------+-----------------+
Note that the exact output depends on the user's locale setting in the client. The example output
presented is using the ``en-GB`` locale.
.. versionadded:: 2.0
Parameters
-----------
dt: :class:`datetime.datetime`
The datetime to format.
style: :class:`str`
The style to format the datetime with.
Returns
--------
:class:`str`
The formatted string.
"""
if style is None:
return f"<t:{int(dt.timestamp())}>"
return f"<t:{int(dt.timestamp())}:{style}>"

View File

@ -20,9 +20,9 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
Some documentation to refer to:
"""Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
@ -37,54 +37,31 @@ Some documentation to refer to:
- Finally we can transmit data to endpoint:port.
"""
from __future__ import annotations
import asyncio
import socket
import logging
import struct
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple
from typing import Any, Callable
from . import opus, utils
from .backoff import ExponentialBackoff
from .gateway import *
from .errors import ClientException, ConnectionClosed
from .player import AudioPlayer, AudioSource
from .utils import MISSING
if TYPE_CHECKING:
from .client import Client
from .guild import Guild
from .state import ConnectionState
from .user import ClientUser
from .opus import Encoder
from . import abc
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes,
)
has_nacl: bool
try:
import nacl.secret # type: ignore
import nacl.secret
has_nacl = True
except ImportError:
has_nacl = False
__all__ = (
"VoiceProtocol",
"VoiceClient",
'VoiceProtocol',
'VoiceClient',
)
_log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@ -107,11 +84,11 @@ class VoiceProtocol:
The voice channel that is being connected to.
"""
def __init__(self, client: Client, channel: abc.Connectable) -> None:
self.client: Client = client
self.channel: abc.Connectable = channel
def __init__(self, client, channel):
self.client = client
self.channel = channel
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
async def on_voice_state_update(self, data):
"""|coro|
An abstract method that is called when the client's voice state
@ -128,7 +105,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
async def on_voice_server_update(self, data):
"""|coro|
An abstract method that is called when initially connecting to voice.
@ -145,7 +122,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def connect(self, *, timeout: float, reconnect: bool) -> None:
async def connect(self, *, timeout: float, reconnect: bool):
"""|coro|
An abstract method called when the client initiates the connection request.
@ -168,7 +145,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def disconnect(self, *, force: bool) -> None:
async def disconnect(self, *, force: bool):
"""|coro|
An abstract method called when the client terminates the connection.
@ -182,7 +159,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
def cleanup(self) -> None:
def cleanup(self):
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are
@ -195,7 +172,6 @@ class VoiceProtocol:
key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
@ -222,57 +198,48 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
endpoint_ip: str
voice_port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable):
def __init__(self, client, channel):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
super().__init__(client, channel)
state = client._connection
self.token: str = MISSING
self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state
self.token = None
self.socket = None
self.loop = state.loop
self._state = state
# this will be used in the AudioPlayer thread
self._connected: threading.Event = threading.Event()
self._connected = threading.Event()
self._handshaking: bool = False
self._potentially_reconnecting: bool = False
self._voice_state_complete: asyncio.Event = asyncio.Event()
self._voice_server_complete: asyncio.Event = asyncio.Event()
self._handshaking = False
self._potentially_reconnecting = False
self._voice_state_complete = asyncio.Event()
self._voice_server_complete = asyncio.Event()
self.mode: str = MISSING
self._connections: int = 0
self.sequence: int = 0
self.timestamp: int = 0
self.timeout: float = 0
self._runner: asyncio.Task = MISSING
self._player: Optional[AudioPlayer] = None
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
self.mode = None
self._connections = 0
self.sequence = 0
self.timestamp = 0
self._runner = None
self._player = None
self.encoder = None
self._lite_nonce = 0
self.ws = None
warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
"xsalsa20_poly1305_lite",
"xsalsa20_poly1305_suffix",
"xsalsa20_poly1305",
supported_modes = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
)
@property
def guild(self) -> Optional[Guild]:
def guild(self):
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, "guild", None)
return getattr(self.channel, 'guild', None)
@property
def user(self) -> ClientUser:
def user(self):
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user
@ -285,9 +252,9 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data["session_id"]
channel_id = data["channel_id"]
async def on_voice_state_update(self, data):
self.session_id = data['session_id']
channel_id = data['channel_id']
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
@ -298,33 +265,31 @@ class VoiceClient(VoiceProtocol):
await self.disconnect()
else:
guild = self.guild
self.channel = channel_id and guild and guild.get_channel(int(channel_id)) # type: ignore
self.channel = channel_id and guild and guild.get_channel(int(channel_id))
else:
self._voice_state_complete.set()
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
async def on_voice_server_update(self, data):
if self._voice_server_complete.is_set():
_log.info("Ignoring extraneous voice server update.")
log.info('Ignoring extraneous voice server update.')
return
self.token = data.get("token")
self.server_id = int(data["guild_id"])
endpoint = data.get("endpoint")
self.token = data.get('token')
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
_log.warning(
"Awaiting endpoint... This requires waiting. "
"If timeout occurred considering raising the timeout and reconnecting."
)
log.warning('Awaiting endpoint... This requires waiting. ' \
'If timeout occurred considering raising the timeout and reconnecting.')
return
self.endpoint, _, _ = endpoint.rpartition(":")
if self.endpoint.startswith("wss://"):
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later
self.endpoint = self.endpoint[6:]
# This gets set later
self.endpoint_ip = MISSING
self.endpoint_ip = None
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
@ -336,29 +301,27 @@ class VoiceClient(VoiceProtocol):
self._voice_server_complete.set()
async def voice_connect(self) -> None:
async def voice_connect(self):
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self) -> None:
_log.info(
"The voice handshake is being terminated for Channel ID %s (Guild ID %s)", self.channel.id, self.guild.id
)
async def voice_disconnect(self):
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
def prepare_handshake(self):
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
_log.info("Starting voice handshake... (connection attempt %d)", self._connections + 1)
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self) -> None:
_log.info("Voice handshake complete. Endpoint found %s", self.endpoint)
def finish_handshake(self):
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
async def connect_websocket(self) -> DiscordVoiceWebSocket:
async def connect_websocket(self):
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
@ -366,8 +329,8 @@ class VoiceClient(VoiceProtocol):
self._connected.set()
return ws
async def connect(self, *, reconnect: bool, timeout: float) -> None:
_log.info("Connecting to voice...")
async def connect(self, *, reconnect: bool, timeout: bool):
log.info('Connecting to voice...')
self.timeout = timeout
for i in range(5):
@ -395,17 +358,17 @@ class VoiceClient(VoiceProtocol):
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
_log.exception("Failed to connect to voice... Retrying...")
log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
else:
raise
if self._runner is MISSING:
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool:
async def potential_reconnect(self):
# Attempt to stop the player thread from playing early
self._connected.clear()
self.prepare_handshake()
@ -428,7 +391,7 @@ class VoiceClient(VoiceProtocol):
return True
@property
def latency(self) -> float:
def latency(self):
""":class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This could be referred to as the Discord Voice WebSocket latency and is
@ -440,7 +403,7 @@ class VoiceClient(VoiceProtocol):
return float("inf") if not ws else ws.latency
@property
def average_latency(self) -> float:
def average_latency(self):
""":class:`float`: Average of most recent 20 HEARTBEAT latencies in seconds.
.. versionadded:: 1.4
@ -448,7 +411,7 @@ class VoiceClient(VoiceProtocol):
ws = self.ws
return float("inf") if not ws else ws.average_latency
async def poll_voice_ws(self, reconnect: bool) -> None:
async def poll_voice_ws(self, reconnect):
backoff = ExponentialBackoff()
while True:
try:
@ -460,14 +423,14 @@ class VoiceClient(VoiceProtocol):
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info("Disconnecting from voice normally, close code %d.", exc.code)
log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
_log.info("Disconnected from voice by force... potentially reconnecting.")
log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self.potential_reconnect()
if not successful:
_log.info("Reconnect was unsuccessful, disconnecting from voice normally...")
log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
@ -478,7 +441,7 @@ class VoiceClient(VoiceProtocol):
raise
retry = backoff.delay()
_log.exception("Disconnected from voice... Reconnecting in %.2fs.", retry)
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
@ -486,10 +449,10 @@ class VoiceClient(VoiceProtocol):
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning("Could not connect to voice... Retrying...")
log.warning('Could not connect to voice... Retrying...')
continue
async def disconnect(self, *, force: bool = False) -> None:
async def disconnect(self, *, force: bool = False):
"""|coro|
Disconnects this voice client from voice.
@ -510,7 +473,7 @@ class VoiceClient(VoiceProtocol):
if self.socket:
self.socket.close()
async def move_to(self, channel: abc.Snowflake) -> None:
async def move_to(self, channel):
"""|coro|
Moves you to a different voice channel.
@ -522,7 +485,7 @@ class VoiceClient(VoiceProtocol):
"""
await self.channel.guild.change_voice_state(channel=channel)
def is_connected(self) -> bool:
def is_connected(self):
"""Indicates if the voice client is connected to voice."""
return self._connected.is_set()
@ -534,36 +497,36 @@ class VoiceClient(VoiceProtocol):
# Formulate rtp header
header[0] = 0x80
header[1] = 0x78
struct.pack_into(">H", header, 2, self.sequence)
struct.pack_into(">I", header, 4, self.timestamp)
struct.pack_into(">I", header, 8, self.ssrc)
struct.pack_into('>H', header, 2, self.sequence)
struct.pack_into('>I', header, 4, self.timestamp)
struct.pack_into('>I', header, 8, self.ssrc)
encrypt_packet = getattr(self, "_encrypt_" + self.mode)
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
return encrypt_packet(header, data)
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
def _encrypt_xsalsa20_poly1305(self, header, data):
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:12] = header
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext
def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes:
def _encrypt_xsalsa20_poly1305_suffix(self, header, data):
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
return header + box.encrypt(bytes(data), nonce).ciphertext + nonce
def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes:
def _encrypt_xsalsa20_poly1305_lite(self, header, data):
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:4] = struct.pack(">I", self._lite_nonce)
self.checked_add("_lite_nonce", 1, 4294967295)
nonce[:4] = struct.pack('>I', self._lite_nonce)
self.checked_add('_lite_nonce', 1, 4294967295)
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None) -> None:
def play(self, source: AudioSource, *, after: Callable[[Exception], Any]=None):
"""Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted
@ -577,7 +540,7 @@ class VoiceClient(VoiceProtocol):
-----------
source: :class:`AudioSource`
The audio source we're reading from.
after: Callable[[Optional[:class:`Exception`]], Any]
after: Callable[[:class:`Exception`], Any]
The finalizer that is called after the stream is exhausted.
This function must have a single parameter, ``error``, that
denotes an optional exception that was raised during playing.
@ -593,13 +556,13 @@ class VoiceClient(VoiceProtocol):
"""
if not self.is_connected():
raise ClientException("Not connected to voice.")
raise ClientException('Not connected to voice.')
if self.is_playing():
raise ClientException("Already playing audio.")
raise ClientException('Already playing audio.')
if not isinstance(source, AudioSource):
raise TypeError(f"source must be an AudioSource not {source.__class__.__name__}")
raise TypeError(f'source must an AudioSource not {source.__class__.__name__}')
if not self.encoder and not source.is_opus():
self.encoder = opus.Encoder()
@ -607,32 +570,32 @@ class VoiceClient(VoiceProtocol):
self._player = AudioPlayer(source, self, after=after)
self._player.start()
def is_playing(self) -> bool:
def is_playing(self):
"""Indicates if we're currently playing audio."""
return self._player is not None and self._player.is_playing()
def is_paused(self) -> bool:
def is_paused(self):
"""Indicates if we're playing audio, but if we're paused."""
return self._player is not None and self._player.is_paused()
def stop(self) -> None:
def stop(self):
"""Stops playing audio."""
if self._player:
self._player.stop()
self._player = None
def pause(self) -> None:
def pause(self):
"""Pauses the audio playing."""
if self._player:
self._player.pause()
def resume(self) -> None:
def resume(self):
"""Resumes the audio playing."""
if self._player:
self._player.resume()
@property
def source(self) -> Optional[AudioSource]:
def source(self):
"""Optional[:class:`AudioSource`]: The audio source being played, if playing.
This property can also be used to change the audio source currently being played.
@ -640,16 +603,16 @@ class VoiceClient(VoiceProtocol):
return self._player.source if self._player else None
@source.setter
def source(self, value: AudioSource) -> None:
def source(self, value):
if not isinstance(value, AudioSource):
raise TypeError(f"expected AudioSource not {value.__class__.__name__}.")
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
if self._player is None:
raise ValueError("Not playing anything.")
raise ValueError('Not playing anything.')
self._player._set_source(value)
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
def send_audio_packet(self, data, *, encode=True):
"""Sends an audio packet composed of the data.
You must be connected to play audio.
@ -669,7 +632,7 @@ class VoiceClient(VoiceProtocol):
Encoding the data failed.
"""
self.checked_add("sequence", 1, 65535)
self.checked_add('sequence', 1, 65535)
if encode:
encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME)
else:
@ -678,6 +641,6 @@ class VoiceClient(VoiceProtocol):
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
_log.warning("A packet has been dropped (seq: %s, timestamp: %s)", self.sequence, self.timestamp)
log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self.checked_add("timestamp", opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)

File diff suppressed because it is too large Load Diff

View File

@ -43,16 +43,16 @@ from .. import utils
from ..errors import InvalidArgument, HTTPException, Forbidden, NotFound, DiscordServerError
from ..message import Message
from ..http import Route
from ..channel import PartialMessageable
from ..object import Object
from .async_ import BaseWebhook, handle_message_parameters, _WebhookState
__all__ = (
"SyncWebhook",
"SyncWebhookMessage",
'SyncWebhook',
'SyncWebhookMessage',
)
_log = logging.getLogger(__name__)
log = logging.getLogger(__name__)
if TYPE_CHECKING:
from ..file import File
@ -116,14 +116,14 @@ class WebhookAdapter:
self._locks[bucket] = lock = threading.Lock()
if payload is not None:
headers["Content-Type"] = "application/json"
to_send = utils._to_json(payload)
headers['Content-Type'] = 'application/json'
to_send = utils.to_json(payload)
if auth_token is not None:
headers["Authorization"] = f"Bot {auth_token}"
headers['Authorization'] = f'Bot {auth_token}'
if reason is not None:
headers["X-Audit-Log-Reason"] = urlquote(reason, safe="/ ")
headers['X-Audit-Log-Reason'] = urlquote(reason, safe='/ ')
response: Optional[Response] = None
data: Optional[Union[Dict[str, Any], str]] = None
@ -140,38 +140,34 @@ class WebhookAdapter:
if multipart:
file_data = {}
for p in multipart:
name = p["name"]
if name == "payload_json":
to_send = {"payload_json": p["value"]}
name = p['name']
if name == 'payload_json':
to_send = { 'payload_json': p['value'] }
else:
file_data[name] = (p["filename"], p["value"], p["content_type"])
file_data[name] = (p['filename'], p['value'], p['content_type'])
try:
with session.request(
method, url, data=to_send, files=file_data, headers=headers, params=params
) as response:
_log.debug(
"Webhook ID %s with %s %s has returned status code %s",
with session.request(method, url, data=to_send, files=file_data, headers=headers, params=params) as response:
log.debug(
'Webhook ID %s with %s %s has returned status code %s',
webhook_id,
method,
url,
response.status_code,
)
response.encoding = "utf-8"
response.encoding = 'utf-8'
# Compatibility with aiohttp
response.status = response.status_code # type: ignore
data = response.text or None
if data and response.headers["Content-Type"] == "application/json":
if data and response.headers['Content-Type'] == 'application/json':
data = json.loads(data)
remaining = response.headers.get("X-Ratelimit-Remaining")
if remaining == "0" and response.status_code != 429:
remaining = response.headers.get('X-Ratelimit-Remaining')
if remaining == '0' and response.status_code != 429:
delta = utils._parse_ratelimit_header(response)
_log.debug(
"Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds",
webhook_id,
delta,
log.debug(
'Webhook ID %s has been pre-emptively rate limited, waiting %.2f seconds', webhook_id, delta
)
lock.delay_by(delta)
@ -179,13 +175,11 @@ class WebhookAdapter:
return data
if response.status_code == 429:
if not response.headers.get("Via"):
if not response.headers.get('Via'):
raise HTTPException(response, data)
retry_after: float = data["retry_after"] # type: ignore
_log.warning(
"Webhook ID %s is rate limited. Retrying in %.2f seconds", webhook_id, retry_after
)
retry_after: float = data['retry_after'] # type: ignore
log.warning('Webhook ID %s is rate limited. Retrying in %.2f seconds', webhook_id, retry_after)
time.sleep(retry_after)
continue
@ -211,7 +205,7 @@ class WebhookAdapter:
raise DiscordServerError(response, data)
raise HTTPException(response, data)
raise RuntimeError("Unreachable code in HTTP handling.")
raise RuntimeError('Unreachable code in HTTP handling.')
def delete_webhook(
self,
@ -221,7 +215,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route("DELETE", "/webhooks/{webhook_id}", webhook_id=webhook_id)
route = Route('DELETE', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, auth_token=token)
def delete_webhook_with_token(
@ -232,7 +226,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route("DELETE", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
route = Route('DELETE', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason)
def edit_webhook(
@ -244,7 +238,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id)
route = Route('PATCH', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session, reason=reason, payload=payload, auth_token=token)
def edit_webhook_with_token(
@ -256,7 +250,7 @@ class WebhookAdapter:
session: Session,
reason: Optional[str] = None,
):
route = Route("PATCH", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
route = Route('PATCH', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, reason=reason, payload=payload)
def execute_webhook(
@ -271,10 +265,10 @@ class WebhookAdapter:
thread_id: Optional[int] = None,
wait: bool = False,
):
params = {"wait": int(wait)}
params = {'wait': int(wait)}
if thread_id:
params["thread_id"] = thread_id
route = Route("POST", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
params['thread_id'] = thread_id
route = Route('POST', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session, payload=payload, multipart=multipart, files=files, params=params)
def get_webhook_message(
@ -286,8 +280,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
"GET",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
'GET',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -306,8 +300,8 @@ class WebhookAdapter:
files: Optional[List[File]] = None,
):
route = Route(
"PATCH",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
'PATCH',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -323,8 +317,8 @@ class WebhookAdapter:
session: Session,
):
route = Route(
"DELETE",
"/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}",
'DELETE',
'/webhooks/{webhook_id}/{webhook_token}/messages/{message_id}',
webhook_id=webhook_id,
webhook_token=token,
message_id=message_id,
@ -338,7 +332,7 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)
route = Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)
return self.request(route, session=session, auth_token=token)
def fetch_webhook_with_token(
@ -348,21 +342,12 @@ class WebhookAdapter:
*,
session: Session,
):
route = Route("GET", "/webhooks/{webhook_id}/{webhook_token}", webhook_id=webhook_id, webhook_token=token)
route = Route('GET', '/webhooks/{webhook_id}/{webhook_token}', webhook_id=webhook_id, webhook_token=token)
return self.request(route, session=session)
class _WebhookContext(threading.local):
adapter: Optional[WebhookAdapter] = None
_context = _WebhookContext()
def _get_webhook_adapter() -> WebhookAdapter:
if _context.adapter is None:
_context.adapter = WebhookAdapter()
return _context.adapter
_context = threading.local()
_context.adapter = WebhookAdapter()
class SyncWebhookMessage(Message):
@ -377,8 +362,6 @@ class SyncWebhookMessage(Message):
.. versionadded:: 2.0
"""
_state: _WebhookState
def edit(
self,
content: Optional[str] = MISSING,
@ -387,7 +370,7 @@ class SyncWebhookMessage(Message):
file: File = MISSING,
files: List[File] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> SyncWebhookMessage:
):
"""Edits the message.
Parameters
@ -420,13 +403,8 @@ class SyncWebhookMessage(Message):
The length of ``embeds`` was invalid
InvalidArgument
There was no token associated with this webhook.
Returns
--------
:class:`SyncWebhookMessage`
The newly edited message.
"""
return self._state._webhook.edit_message(
self._state._webhook.edit_message(
self.id,
content=content,
embeds=embeds,
@ -479,10 +457,6 @@ class SyncWebhook(BaseWebhook):
Returns the webhooks's hash.
.. describe:: int(x)
Returns the webhooks's ID.
.. versionchanged:: 1.4
Webhooks are now comparable and hashable.
@ -520,24 +494,22 @@ class SyncWebhook(BaseWebhook):
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = ("session",)
__slots__: Tuple[str, ...] = BaseWebhook.__slots__ + ('session',)
def __init__(self, data: WebhookPayload, session: Session, token: Optional[str] = None, state=None):
super().__init__(data, token, state)
self.session = session
def __repr__(self):
return f"<Webhook id={self.id!r}>"
return f'<Webhook id={self.id!r}>'
@property
def url(self) -> str:
def url(self):
""":class:`str` : Returns the webhook's url."""
return f"https://discord.com/api/webhooks/{self.id}/{self.token}"
return f'https://discord.com/api/webhooks/{self.id}/{self.token}'
@classmethod
def partial(
cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None
) -> SyncWebhook:
def partial(cls, id: int, token: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
"""Creates a partial :class:`Webhook`.
Parameters
@ -562,21 +534,21 @@ class SyncWebhook(BaseWebhook):
A partial webhook is just a webhook object with an ID and a token.
"""
data: WebhookPayload = {
"id": id,
"type": 1,
"token": token,
'id': id,
'type': 1,
'token': token,
}
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f"expected requests.Session not {session.__class__!r}")
raise TypeError(f'expected requests.Session not {session.__class__!r}')
else:
session = requests # type: ignore
return cls(data, session, token=bot_token)
@classmethod
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook:
def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None):
"""Creates a partial :class:`Webhook` from a webhook URL.
Parameters
@ -603,17 +575,17 @@ class SyncWebhook(BaseWebhook):
A partial :class:`Webhook`.
A partial webhook is just a webhook object with an ID and a token.
"""
m = re.search(r"discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})", url)
m = re.search(r'discord(?:app)?.com/api/webhooks/(?P<id>[0-9]{17,20})/(?P<token>[A-Za-z0-9\.\-\_]{60,68})', url)
if m is None:
raise InvalidArgument("Invalid webhook URL given.")
raise InvalidArgument('Invalid webhook URL given.')
data: Dict[str, Any] = m.groupdict()
data["type"] = 1
data['type'] = 1
import requests
if session is not MISSING:
if not isinstance(session, requests.Session):
raise TypeError(f"expected requests.Session not {session.__class__!r}")
raise TypeError(f'expected requests.Session not {session.__class__!r}')
else:
session = requests # type: ignore
return cls(data, session, token=bot_token) # type: ignore
@ -649,18 +621,18 @@ class SyncWebhook(BaseWebhook):
:class:`SyncWebhook`
The fetched webhook.
"""
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
if prefer_auth and self.auth_token:
data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session)
elif self.token:
data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session)
else:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
return SyncWebhook(data, self.session, token=self.auth_token, state=self._state)
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> None:
def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True):
"""Deletes this Webhook.
Parameters
@ -685,9 +657,9 @@ class SyncWebhook(BaseWebhook):
This webhook does not have a token associated with it.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
if prefer_auth and self.auth_token:
adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason)
@ -702,7 +674,7 @@ class SyncWebhook(BaseWebhook):
avatar: Optional[bytes] = MISSING,
channel: Optional[Snowflake] = None,
prefer_auth: bool = True,
) -> SyncWebhook:
):
"""Edits this Webhook.
Parameters
@ -730,51 +702,40 @@ class SyncWebhook(BaseWebhook):
InvalidArgument
This webhook does not have a token associated with it
or it tried editing a channel without authentication.
Returns
--------
:class:`SyncWebhook`
The newly edited webhook.
"""
if self.token is None and self.auth_token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
payload = {}
if name is not MISSING:
payload["name"] = str(name) if name is not None else None
payload['name'] = str(name) if name is not None else None
if avatar is not MISSING:
payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
payload['avatar'] = utils._bytes_to_base64_data(avatar) if avatar is not None else None
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
data: Optional[WebhookPayload] = None
# If a channel is given, always use the authenticated endpoint
if channel is not None:
if self.auth_token is None:
raise InvalidArgument("Editing channel requires authenticated webhook")
raise InvalidArgument('Editing channel requires authenticated webhook')
payload["channel_id"] = channel.id
payload['channel_id'] = channel.id
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
return
if prefer_auth and self.auth_token:
data = adapter.edit_webhook(self.id, self.auth_token, payload=payload, session=self.session, reason=reason)
self._update(data)
elif self.token:
data = adapter.edit_webhook_with_token(
self.id, self.token, payload=payload, session=self.session, reason=reason
)
if data is None:
raise RuntimeError("Unreachable code hit: data was not assigned")
return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state)
data = adapter.edit_webhook_with_token(self.id, self.token, payload=payload, session=self.session, reason=reason)
self._update(data)
def _create_message(self, data):
state = _WebhookState(self, parent=self._state)
# state may be artificial (unlikely at this point...)
channel = self.channel or PartialMessageable(state=self._state, id=int(data["channel_id"])) # type: ignore
# state is artificial
return SyncWebhookMessage(data=data, state=state, channel=channel) # type: ignore
channel = self.channel or Object(id=int(data['channel_id']))
return SyncWebhookMessage(data=data, state=state, channel=channel)
@overload
def send(
@ -782,7 +743,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: Any = MISSING,
avatar_url: str = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
@ -799,7 +760,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: Any = MISSING,
avatar_url: str = MISSING,
tts: bool = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
@ -815,7 +776,7 @@ class SyncWebhook(BaseWebhook):
content: str = MISSING,
*,
username: str = MISSING,
avatar_url: Any = MISSING,
avatar_url: str = MISSING,
tts: bool = False,
file: File = MISSING,
files: List[File] = MISSING,
@ -847,10 +808,9 @@ class SyncWebhook(BaseWebhook):
username: :class:`str`
The username to send with this message. If no username is provided
then the default username for the webhook is used.
avatar_url: :class:`str`
avatar_url: Union[:class:`str`, :class:`Asset`]
The avatar URL to send with this message. If no avatar URL is provided
then the default avatar for the webhook is used. If this is not a
string then it is explicitly cast using ``str``.
then the default avatar for the webhook is used.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
file: :class:`File`
@ -895,9 +855,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
if content is None:
content = MISSING
@ -913,7 +873,7 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
thread_id: Optional[int] = None
if thread is not MISSING:
thread_id = thread.id
@ -931,7 +891,7 @@ class SyncWebhook(BaseWebhook):
if wait:
return self._create_message(data)
def fetch_message(self, id: int, /) -> SyncWebhookMessage:
def fetch_message(self, id: int) -> SyncWebhookMessage:
"""Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook.
.. versionadded:: 2.0
@ -959,9 +919,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
data = adapter.get_webhook_message(
self.id,
self.token,
@ -980,7 +940,7 @@ class SyncWebhook(BaseWebhook):
file: File = MISSING,
files: List[File] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> SyncWebhookMessage:
):
"""Edits a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.edit` in case
@ -1023,9 +983,9 @@ class SyncWebhook(BaseWebhook):
"""
if self.token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None)
previous_mentions: Optional[AllowedMentions] = getattr(self._state, 'allowed_mentions', None)
params = handle_message_parameters(
content=content,
file=file,
@ -1035,8 +995,8 @@ class SyncWebhook(BaseWebhook):
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter: WebhookAdapter = _get_webhook_adapter()
data = adapter.edit_webhook_message(
adapter: WebhookAdapter = _context.adapter
adapter.edit_webhook_message(
self.id,
self.token,
message_id,
@ -1045,9 +1005,8 @@ class SyncWebhook(BaseWebhook):
multipart=params.multipart,
files=params.files,
)
return self._create_message(data)
def delete_message(self, message_id: int, /) -> None:
def delete_message(self, message_id: int):
"""Deletes a message owned by this webhook.
This is a lower level interface to :meth:`WebhookMessage.delete` in case
@ -1068,9 +1027,9 @@ class SyncWebhook(BaseWebhook):
Deleted a message that is not yours.
"""
if self.token is None:
raise InvalidArgument("This webhook does not have a token associated with it")
raise InvalidArgument('This webhook does not have a token associated with it')
adapter: WebhookAdapter = _get_webhook_adapter()
adapter: WebhookAdapter = _context.adapter
adapter.delete_webhook_message(
self.id,
self.token,

View File

@ -41,12 +41,11 @@ if TYPE_CHECKING:
)
__all__ = (
"WidgetChannel",
"WidgetMember",
"Widget",
'WidgetChannel',
'WidgetMember',
'Widget',
)
class WidgetChannel:
"""Represents a "partial" widget channel.
@ -77,8 +76,7 @@ class WidgetChannel:
position: :class:`int`
The channel's position
"""
__slots__ = ("id", "name", "position")
__slots__ = ('id', 'name', 'position')
def __init__(self, id: int, name: str, position: int) -> None:
self.id: int = id
@ -89,19 +87,18 @@ class WidgetChannel:
return self.name
def __repr__(self) -> str:
return f"<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>"
return f'<WidgetChannel id={self.id} name={self.name!r} position={self.position!r}>'
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the channel."""
return f"<#{self.id}>"
return f'<#{self.id}>'
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id)
class WidgetMember(BaseUser):
"""Represents a "partial" member of the widget's guild.
@ -150,37 +147,29 @@ class WidgetMember(BaseUser):
connected_channel: Optional[:class:`WidgetChannel`]
Which channel the member is connected to.
"""
__slots__ = (
"name",
"status",
"nick",
"avatar",
"discriminator",
"id",
"bot",
"activity",
"deafened",
"suppress",
"muted",
"connected_channel",
)
__slots__ = ('name', 'status', 'nick', 'avatar', 'discriminator',
'id', 'bot', 'activity', 'deafened', 'suppress', 'muted',
'connected_channel')
if TYPE_CHECKING:
activity: Optional[Union[BaseActivity, Spotify]]
def __init__(
self, *, state: ConnectionState, data: WidgetMemberPayload, connected_channel: Optional[WidgetChannel] = None
self,
*,
state: ConnectionState,
data: WidgetMemberPayload,
connected_channel: Optional[WidgetChannel] = None
) -> None:
super().__init__(state=state, data=data)
self.nick: Optional[str] = data.get("nick")
self.status: Status = try_enum(Status, data.get("status"))
self.deafened: Optional[bool] = data.get("deaf", False) or data.get("self_deaf", False)
self.muted: Optional[bool] = data.get("mute", False) or data.get("self_mute", False)
self.suppress: Optional[bool] = data.get("suppress", False)
self.nick: Optional[str] = data.get('nick')
self.status: Status = try_enum(Status, data.get('status'))
self.deafened: Optional[bool] = data.get('deaf', False) or data.get('self_deaf', False)
self.muted: Optional[bool] = data.get('mute', False) or data.get('self_mute', False)
self.suppress: Optional[bool] = data.get('suppress', False)
try:
game = data["game"]
game = data['game']
except KeyError:
activity = None
else:
@ -201,7 +190,6 @@ class WidgetMember(BaseUser):
""":class:`str`: Returns the member's display name."""
return self.nick or self.name
class Widget:
"""Represents a :class:`Guild` widget.
@ -239,28 +227,27 @@ class Widget:
retrieved is capped.
"""
__slots__ = ("_state", "channels", "_invite", "id", "members", "name")
__slots__ = ('_state', 'channels', '_invite', 'id', 'members', 'name')
def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None:
self._state = state
self._invite = data["instant_invite"]
self.name: str = data["name"]
self.id: int = int(data["id"])
self._invite = data['instant_invite']
self.name: str = data['name']
self.id: int = int(data['id'])
self.channels: List[WidgetChannel] = []
for channel in data.get("channels", []):
_id = int(channel["id"])
self.channels.append(WidgetChannel(id=_id, name=channel["name"], position=channel["position"]))
for channel in data.get('channels', []):
_id = int(channel['id'])
self.channels.append(WidgetChannel(id=_id, name=channel['name'], position=channel['position']))
self.members: List[WidgetMember] = []
channels = {channel.id: channel for channel in self.channels}
for member in data.get("members", []):
connected_channel = _get_as_snowflake(member, "channel_id")
for member in data.get('members', []):
connected_channel = _get_as_snowflake(member, 'channel_id')
if connected_channel in channels:
connected_channel = channels[connected_channel] # type: ignore
elif connected_channel:
connected_channel = WidgetChannel(id=connected_channel, name="", position=0)
connected_channel = WidgetChannel(id=connected_channel, name='', position=0)
self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore
@ -273,7 +260,7 @@ class Widget:
return False
def __repr__(self) -> str:
return f"<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>"
return f'<Widget id={self.id} name={self.name!r} invite_url={self.invite_url!r}>'
@property
def created_at(self) -> datetime.datetime:

View File

@ -772,7 +772,8 @@ li > blockquote {
/* admonitions */
div.admonition {
padding: 0 0.8em 0.8em 0.8em !important;
padding: 0 0.8em;
padding-bottom: 0.8em;
margin: 0.8em 0;
border-radius: 2.5px;
border-left-width: 6px;
@ -782,7 +783,7 @@ div.admonition {
p.admonition-title {
font-weight: bold;
margin: 0 -0.8rem !important;
margin: 0 -0.8rem;
padding: 0.4rem 0.6rem 0.4rem 2.5rem;
position: relative;
-moz-user-select: none;
@ -1040,18 +1041,12 @@ dl.function > dt,
dl.attribute > dt,
dl.classmethod > dt,
dl.method > dt,
dl.property > dt,
dl.class > dt,
dl.exception > dt {
background-color: var(--api-entry-background);
padding: 1px 10px;
}
/* bug in sphinx: https://github.com/sphinx-doc/sphinx/issues/9384 */
dl.property > dt > span.descname + em.property {
display: none;
}
dd {
margin-top: 0.5em;
margin-bottom: 0.5em;
@ -1147,10 +1142,6 @@ table.docutils tbody tr td ol.last {
margin-bottom: 0;
}
.align-default {
text-align: left !important;
}
/* hide the welcome text */
section#welcome-to-discord-py > h1 {
display: none;

View File

@ -7,6 +7,9 @@
{%- block extrahead %} {% endblock %}
<!-- end extra head -->
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
<link rel="stylesheet" href="{{ pathto('_static/style.css', 1)|e }}" type="text/css" />
<link rel="stylesheet" href="{{ pathto('_static/codeblocks.css', 1) }}" type="text/css" />
<link rel="stylesheet" href="{{ pathto('_static/icons.css', 1)|e }}" type="text/css" />
{%- block css %}
{%- for css in css_files %}
{%- if css|attr("filename") %}
@ -16,9 +19,6 @@
{%- endif %}
{%- endfor %}
{%- endblock %}
<link rel="stylesheet" href="{{ pathto('_static/style.css', 1)|e }}" type="text/css" />
<link rel="stylesheet" href="{{ pathto('_static/codeblocks.css', 1) }}" type="text/css" />
<link rel="stylesheet" href="{{ pathto('_static/icons.css', 1)|e }}" type="text/css" />
{%- block scripts %}
<script id="documentation_options" data-url_root="{{ pathto('', 1) }}" src="{{ pathto('_static/documentation_options.js', 1) }}"></script>
{%- for js in script_files %}

File diff suppressed because it is too large Load Diff

View File

@ -18,45 +18,43 @@ import re
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath(".."))
sys.path.append(os.path.abspath("extensions"))
sys.path.insert(0, os.path.abspath('..'))
sys.path.append(os.path.abspath('extensions'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
# needs_sphinx = '1.0'
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"builder",
"sphinx.ext.autodoc",
"sphinx.ext.extlinks",
"sphinx.ext.intersphinx",
"sphinx.ext.napoleon",
"sphinxcontrib_trio",
"details",
"exception_hierarchy",
"attributetable",
"resourcelinks",
"nitpick_file_ignorer",
'builder',
'sphinx.ext.autodoc',
'sphinx.ext.extlinks',
'sphinx.ext.intersphinx',
'sphinx.ext.napoleon',
'sphinxcontrib_trio',
'details',
'exception_hierarchy',
'attributetable',
'resourcelinks',
'nitpick_file_ignorer',
]
autodoc_member_order = "bysource"
autodoc_typehints = "none"
# maybe consider this?
# napoleon_attr_annotations = False
autodoc_member_order = 'bysource'
autodoc_typehints = 'none'
extlinks = {
"issue": ("https://github.com/iDevision/enhanced-discord.py/issues/%s", "GH-"),
'issue': ('https://github.com/Rapptz/discord.py/issues/%s', 'GH-'),
}
# Links used for cross-referencing stuff in other documentation
intersphinx_mapping = {
"py": ("https://docs.python.org/3", None),
"aio": ("https://docs.aiohttp.org/en/stable/", None),
"req": ("https://docs.python-requests.org/en/latest/", None),
'py': ('https://docs.python.org/3', None),
'aio': ('https://docs.aiohttp.org/en/stable/', None),
'req': ('https://docs.python-requests.org/en/latest/', None)
}
rst_prolog = """
@ -67,20 +65,20 @@ rst_prolog = """
"""
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
templates_path = ['_templates']
# The suffix of source filenames.
source_suffix = ".rst"
source_suffix = '.rst'
# The encoding of source files.
# source_encoding = 'utf-8-sig'
#source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = "index"
master_doc = 'index'
# General information about the project.
project = "discord.py"
copyright = "2015-present, Rapptz"
project = 'discord.py'
copyright = '2015-present, Rapptz'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
@ -88,15 +86,15 @@ copyright = "2015-present, Rapptz"
#
# The short X.Y version.
version = ""
with open("../discord/__init__.py") as f:
version = ''
with open('../discord/__init__.py') as f:
version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]', f.read(), re.MULTILINE).group(1)
# The full version, including alpha/beta/rc tags.
release = version
# This assumes a tag is available for final releases
branch = "master" if version.endswith("a") else "v" + version
branch = 'master' if version.endswith('a') else 'v' + version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@ -105,49 +103,49 @@ branch = "master" if version.endswith("a") else "v" + version
# Usually you set "language" from the command line for these cases.
language = None
locale_dirs = ["locale/"]
locale_dirs = ['locale/']
gettext_compact = False
# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
# today = ''
#today = ''
# Else, today_fmt is used as the format for a strftime call.
# today_fmt = '%B %d, %Y'
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ["_build"]
exclude_patterns = ['_build']
# The reST default role (used for this markup: `text`) to use for all
# documents.
# default_role = None
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
# add_function_parentheses = True
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
# add_module_names = True
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
# show_authors = False
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "friendly"
pygments_style = 'friendly'
# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []
#modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents.
# keep_warnings = False
#keep_warnings = False
# Nitpicky mode options
nitpick_ignore_files = [
"migrating_to_async",
"migrating",
"whats_new",
"migrating_to_async",
"migrating",
"whats_new",
]
# -- Options for HTML output ----------------------------------------------
@ -156,20 +154,21 @@ html_experimental_html5_writer = True
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "basic"
html_theme = 'basic'
html_context = {
"discord_invite": "https://discord.gg/TvqYBrGXEm",
"discord_extensions": [
("discord.ext.commands", "ext/commands"),
("discord.ext.tasks", "ext/tasks"),
],
'discord_invite': 'https://discord.gg/r3sSKJJ',
'discord_extensions': [
('discord.ext.commands', 'ext/commands'),
('discord.ext.tasks', 'ext/tasks'),
],
}
resource_links = {
"discord": "https://discord.gg/TvqYBrGXEm",
"issues": "https://github.com/iDevision/enhanced-discord.py/issues",
"examples": f"https://github.com/iDevision/enhanced-discord.py/tree/{branch}/examples",
'discord': 'https://discord.gg/r3sSKJJ',
'issues': 'https://github.com/Rapptz/discord.py/issues',
'discussions': 'https://github.com/Rapptz/discord.py/discussions',
'examples': f'https://github.com/Rapptz/discord.py/tree/{branch}/examples',
}
# Theme options are theme-specific and customize the look and feel of a theme
@ -179,143 +178,155 @@ resource_links = {
# }
# Add any paths that contain custom themes here, relative to this directory.
# html_theme_path = []
#html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
# html_title = None
#html_title = None
# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None
#html_short_title = None
# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
# html_logo = None
#html_logo = None
# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
html_favicon = "./images/discord_py_logo.ico"
html_favicon = './images/discord_py_logo.ico'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_static_path = ['_static']
# Add any extra paths that contain custom files (such as robots.txt or
# .htaccess) here, relative to this directory. These files are copied
# directly to the root of the documentation.
# html_extra_path = []
#html_extra_path = []
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
# html_last_updated_fmt = '%b %d, %Y'
#html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
# html_use_smartypants = True
#html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
# html_sidebars = {}
#html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names to
# template names.
# html_additional_pages = {}
#html_additional_pages = {}
# If false, no module index is generated.
# html_domain_indices = True
#html_domain_indices = True
# If false, no index is generated.
# html_use_index = True
#html_use_index = True
# If true, the index is split into individual pages for each letter.
# html_split_index = False
#html_split_index = False
# If true, links to the reST sources are added to the pages.
# html_show_sourcelink = True
#html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
# html_show_sphinx = True
#html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
# html_show_copyright = True
#html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
# html_use_opensearch = ''
#html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
# html_file_suffix = None
#html_file_suffix = None
# Language to be used for generating the HTML full-text search index.
# Sphinx supports the following languages:
# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
# html_search_language = 'en'
#html_search_language = 'en'
# A dictionary with options for the search language support, empty by default.
# Now only 'ja' uses this config value
# html_search_options = {'type': 'default'}
#html_search_options = {'type': 'default'}
# The name of a javascript file (relative to the configuration directory) that
# implements a search results scorer. If empty, the default will be used.
html_search_scorer = "_static/scorer.js"
html_search_scorer = '_static/scorer.js'
html_js_files = ["custom.js", "settings.js", "copy.js", "sidebar.js"]
html_js_files = [
'custom.js',
'settings.js',
'copy.js',
'sidebar.js'
]
# Output file base name for HTML help builder.
htmlhelp_basename = "discord.pydoc"
htmlhelp_basename = 'discord.pydoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
# Latex figure (float) alignment
#'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
("index", "discord.py.tex", "discord.py Documentation", "Rapptz", "manual"),
('index', 'discord.py.tex', 'discord.py Documentation',
'Rapptz', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
# the title page.
# latex_logo = None
#latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
# latex_use_parts = False
#latex_use_parts = False
# If true, show page references after internal links.
# latex_show_pagerefs = False
#latex_show_pagerefs = False
# If true, show URL addresses after external links.
# latex_show_urls = False
#latex_show_urls = False
# Documents to append as an appendix to all manuals.
# latex_appendices = []
#latex_appendices = []
# If false, no module index is generated.
# latex_domain_indices = True
#latex_domain_indices = True
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
man_pages = [
('index', 'discord.py', 'discord.py Documentation',
['Rapptz'], 1)
]
# If true, show URL addresses after external links.
# man_show_urls = False
#man_show_urls = False
# -- Options for Texinfo output -------------------------------------------
@ -324,32 +335,25 @@ man_pages = [("index", "discord.py", "discord.py Documentation", ["Rapptz"], 1)]
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
"index",
"discord.py",
"discord.py Documentation",
"Rapptz",
"discord.py",
"One line description of project.",
"Miscellaneous",
),
('index', 'discord.py', 'discord.py Documentation',
'Rapptz', 'discord.py', 'One line description of project.',
'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.
# texinfo_appendices = []
#texinfo_appendices = []
# If false, no module index is generated.
# texinfo_domain_indices = True
#texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
# texinfo_show_urls = 'footnote'
#texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
# texinfo_no_detailmenu = False
#texinfo_no_detailmenu = False
def setup(app):
if app.config.language == "ja":
app.config.intersphinx_mapping["py"] = ("https://docs.python.org/ja/3", None)
app.config.html_context["discord_invite"] = "https://discord.gg/TvqYBrGXEm"
app.config.resource_links["discord"] = "https://discord.gg/TvqYBrGXEm"
if app.config.language == 'ja':
app.config.intersphinx_mapping['py'] = ('https://docs.python.org/ja/3', None)
app.config.html_context['discord_invite'] = 'https://discord.gg/nXzj3dg'
app.config.resource_links['discord'] = 'https://discord.gg/nXzj3dg'

View File

@ -18,31 +18,6 @@ Bot
.. autoclass:: discord.ext.commands.Bot
:members:
:inherited-members:
:exclude-members: after_invoke, before_invoke, check, check_once, command, event, group, listen
.. automethod:: Bot.after_invoke()
:decorator:
.. automethod:: Bot.before_invoke()
:decorator:
.. automethod:: Bot.check()
:decorator:
.. automethod:: Bot.check_once()
:decorator:
.. automethod:: Bot.command(*args, **kwargs)
:decorator:
.. automethod:: Bot.event()
:decorator:
.. automethod:: Bot.group(*args, **kwargs)
:decorator:
.. automethod:: Bot.listen(name=None)
:decorator:
AutoShardedBot
~~~~~~~~~~~~~~~~
@ -109,10 +84,8 @@ Decorators
~~~~~~~~~~~~
.. autofunction:: discord.ext.commands.command
:decorator:
.. autofunction:: discord.ext.commands.group
:decorator:
Command
~~~~~~~~~
@ -122,16 +95,6 @@ Command
.. autoclass:: discord.ext.commands.Command
:members:
:special-members: __call__
:exclude-members: after_invoke, before_invoke, error
.. automethod:: Command.after_invoke()
:decorator:
.. automethod:: Command.before_invoke()
:decorator:
.. automethod:: Command.error()
:decorator:
Group
~~~~~~
@ -141,22 +104,6 @@ Group
.. autoclass:: discord.ext.commands.Group
:members:
:inherited-members:
:exclude-members: after_invoke, before_invoke, command, error, group
.. automethod:: Group.after_invoke()
:decorator:
.. automethod:: Group.before_invoke()
:decorator:
.. automethod:: Group.command(*args, **kwargs)
:decorator:
.. automethod:: Group.error()
:decorator:
.. automethod:: Group.group(*args, **kwargs)
:decorator:
GroupMixin
~~~~~~~~~~~
@ -165,13 +112,6 @@ GroupMixin
.. autoclass:: discord.ext.commands.GroupMixin
:members:
:exclude-members: command, group
.. automethod:: GroupMixin.command(*args, **kwargs)
:decorator:
.. automethod:: GroupMixin.group(*args, **kwargs)
:decorator:
.. _ext_commands_api_cogs:
@ -271,73 +211,44 @@ Enums
Checks
-------
.. autofunction:: discord.ext.commands.check(predicate)
:decorator:
.. autofunction:: discord.ext.commands.check
.. autofunction:: discord.ext.commands.check_any(*checks)
:decorator:
.. autofunction:: discord.ext.commands.check_any
.. autofunction:: discord.ext.commands.has_role(item)
:decorator:
.. autofunction:: discord.ext.commands.has_role
.. autofunction:: discord.ext.commands.has_permissions(**perms)
:decorator:
.. autofunction:: discord.ext.commands.has_permissions
.. autofunction:: discord.ext.commands.has_guild_permissions(**perms)
:decorator:
.. autofunction:: discord.ext.commands.has_guild_permissions
.. autofunction:: discord.ext.commands.has_any_role(*items)
:decorator:
.. autofunction:: discord.ext.commands.has_any_role
.. autofunction:: discord.ext.commands.bot_has_role(item)
:decorator:
.. autofunction:: discord.ext.commands.bot_has_role
.. autofunction:: discord.ext.commands.bot_has_permissions(**perms)
:decorator:
.. autofunction:: discord.ext.commands.bot_has_permissions
.. autofunction:: discord.ext.commands.bot_has_guild_permissions(**perms)
:decorator:
.. autofunction:: discord.ext.commands.bot_has_guild_permissions
.. autofunction:: discord.ext.commands.bot_has_any_role(*items)
:decorator:
.. autofunction:: discord.ext.commands.bot_has_any_role
.. autofunction:: discord.ext.commands.cooldown(rate, per, type=discord.ext.commands.BucketType.default)
:decorator:
.. autofunction:: discord.ext.commands.cooldown
.. autofunction:: discord.ext.commands.dynamic_cooldown(cooldown, type=BucketType.default)
:decorator:
.. autofunction:: discord.ext.commands.max_concurrency
.. autofunction:: discord.ext.commands.max_concurrency(number, per=discord.ext.commands.BucketType.default, *, wait=False)
:decorator:
.. autofunction:: discord.ext.commands.before_invoke
.. autofunction:: discord.ext.commands.before_invoke(coro)
:decorator:
.. autofunction:: discord.ext.commands.after_invoke
.. autofunction:: discord.ext.commands.after_invoke(coro)
:decorator:
.. autofunction:: discord.ext.commands.guild_only
.. autofunction:: discord.ext.commands.guild_only(,)
:decorator:
.. autofunction:: discord.ext.commands.dm_only
.. autofunction:: discord.ext.commands.dm_only(,)
:decorator:
.. autofunction:: discord.ext.commands.is_owner
.. autofunction:: discord.ext.commands.is_owner(,)
:decorator:
.. autofunction:: discord.ext.commands.is_nsfw(,)
:decorator:
.. autofunction:: discord.ext.commands.is_nsfw
.. _ext_commands_api_context:
Cooldown
---------
.. attributetable:: discord.ext.commands.Cooldown
.. autoclass:: discord.ext.commands.Cooldown
:members:
Context
--------
@ -416,12 +327,6 @@ Converters
.. autoclass:: discord.ext.commands.PartialEmojiConverter
:members:
.. autoclass:: discord.ext.commands.ThreadConverter
:members:
.. autoclass:: discord.ext.commands.GuildStickerConverter
:members:
.. autoclass:: discord.ext.commands.clean_content
:members:
@ -429,12 +334,6 @@ Converters
.. autofunction:: discord.ext.commands.run_converters
Option
~~~~~~
.. autoclass:: discord.ext.commands.Option
:members:
Flag Converter
~~~~~~~~~~~~~~~
@ -535,9 +434,6 @@ Exceptions
.. autoexception:: discord.ext.commands.ChannelNotReadable
:members:
.. autoexception:: discord.ext.commands.ThreadNotFound
:members:
.. autoexception:: discord.ext.commands.BadColourArgument
:members:
@ -553,9 +449,6 @@ Exceptions
.. autoexception:: discord.ext.commands.PartialEmojiConversionFailure
:members:
.. autoexception:: discord.ext.commands.GuildStickerNotFound
:members:
.. autoexception:: discord.ext.commands.BadBoolArgument
:members:
@ -631,7 +524,6 @@ Exception Hierarchy
- :exc:`~.commands.BadArgument`
- :exc:`~.commands.MessageNotFound`
- :exc:`~.commands.MemberNotFound`
- :exc:`~.commands.GuildNotFound`
- :exc:`~.commands.UserNotFound`
- :exc:`~.commands.ChannelNotFound`
- :exc:`~.commands.ChannelNotReadable`
@ -639,10 +531,8 @@ Exception Hierarchy
- :exc:`~.commands.RoleNotFound`
- :exc:`~.commands.BadInviteArgument`
- :exc:`~.commands.EmojiNotFound`
- :exc:`~.commands.GuildStickerNotFound`
- :exc:`~.commands.PartialEmojiConversionFailure`
- :exc:`~.commands.BadBoolArgument`
- :exc:`~.commands.ThreadNotFound`
- :exc:`~.commands.FlagError`
- :exc:`~.commands.BadFlagArgument`
- :exc:`~.commands.MissingFlagArgument`

View File

@ -5,7 +5,7 @@
Commands
==========
One of the most appealing aspects of the command extension is how easy it is to define commands and
One of the most appealing aspect of the command extension is how easy it is to define commands and
how you can arbitrarily nest groups and commands to have a rich sub-command system.
Commands are defined by attaching it to a regular Python function. The command is then invoked by the user using a similar
@ -61,7 +61,6 @@ the name to something other than the function would be as simple as doing this:
async def _list(ctx, arg):
pass
Parameters
------------
@ -134,11 +133,6 @@ at all:
Since the ``args`` variable is a :class:`py:tuple`,
you can do anything you would usually do with one.
.. admonition:: Slash Command Only
This functionally is currently not supported by the slash command API, so is turned into
a single ``STRING`` parameter on discord's end which we do our own parsing on.
Keyword-Only Arguments
++++++++++++++++++++++++
@ -185,12 +179,6 @@ know how the command was executed. It contains a lot of useful information:
The context implements the :class:`abc.Messageable` interface, so anything you can do on a :class:`abc.Messageable` you
can do on the :class:`~ext.commands.Context`.
.. admonition:: Slash Command Only
:attr:`.Context.message` will be fake if in a slash command, it is not
recommended to access if :attr:`.Context.interaction` is not None as most
methods will error due to the message not actually existing.
Converters
------------
@ -404,7 +392,6 @@ A lot of discord models work out of the gate as a parameter:
- :class:`Colour`
- :class:`Emoji`
- :class:`PartialEmoji`
- :class:`Thread` (since v2.0)
Having any of these set as the converter will intelligently convert the argument to the appropriate target type you
specify.
@ -412,55 +399,45 @@ specify.
Under the hood, these are implemented by the :ref:`ext_commands_adv_converters` interface. A table of the equivalent
converter is given below:
+--------------------------+-------------------------------------------------+-----------------------------+
| Discord Class | Converter | Supported By Slash Commands |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Object` | :class:`~ext.commands.ObjectConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Member` | :class:`~ext.commands.MemberConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`User` | :class:`~ext.commands.UserConverter` | Yes, as type 6 (USER) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Message` | :class:`~ext.commands.MessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Thread` | :class:`~ext.commands.ThreadConverter` | Yes, as type 7 (CHANNEL) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Invite` | :class:`~ext.commands.InviteConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Guild` | :class:`~ext.commands.GuildConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Role` | :class:`~ext.commands.RoleConverter` | Yes, as type 8 (ROLE) |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Game` | :class:`~ext.commands.GameConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Colour` | :class:`~ext.commands.ColourConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`Emoji` | :class:`~ext.commands.EmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
| :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` | Not currently |
+--------------------------+-------------------------------------------------+-----------------------------+
.. admonition:: Slash Command Only
If a slash command is not marked on the table above as supported, it will be sent as type 3 (STRING)
and parsed by normal content parsing, see
`the discord documentation <https://discord.com/developers/docs/interactions/application-commands#application-command-object-application-command-option-type>`_
for all supported types by the API.
+--------------------------+-------------------------------------------------+
| Discord Class | Converter |
+--------------------------+-------------------------------------------------+
| :class:`Object` | :class:`~ext.commands.ObjectConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Member` | :class:`~ext.commands.MemberConverter` |
+--------------------------+-------------------------------------------------+
| :class:`User` | :class:`~ext.commands.UserConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Message` | :class:`~ext.commands.MessageConverter` |
+--------------------------+-------------------------------------------------+
| :class:`PartialMessage` | :class:`~ext.commands.PartialMessageConverter` |
+--------------------------+-------------------------------------------------+
| :class:`.GuildChannel` | :class:`~ext.commands.GuildChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`TextChannel` | :class:`~ext.commands.TextChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`VoiceChannel` | :class:`~ext.commands.VoiceChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`StageChannel` | :class:`~ext.commands.StageChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`StoreChannel` | :class:`~ext.commands.StoreChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`CategoryChannel` | :class:`~ext.commands.CategoryChannelConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Invite` | :class:`~ext.commands.InviteConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Guild` | :class:`~ext.commands.GuildConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Role` | :class:`~ext.commands.RoleConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Game` | :class:`~ext.commands.GameConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Colour` | :class:`~ext.commands.ColourConverter` |
+--------------------------+-------------------------------------------------+
| :class:`Emoji` | :class:`~ext.commands.EmojiConverter` |
+--------------------------+-------------------------------------------------+
| :class:`PartialEmoji` | :class:`~ext.commands.PartialEmojiConverter` |
+--------------------------+-------------------------------------------------+
By providing the converter it allows us to use them as building blocks for another converter:
@ -507,10 +484,6 @@ then a special error is raised, :exc:`~ext.commands.BadUnionArgument`.
Note that any valid converter discussed above can be passed in to the argument list of a :data:`typing.Union`.
.. admonition:: Slash Command Only
These are not currently supported by the Discord API and will be sent as type 3 (STRING)
typing.Optional
^^^^^^^^^^^^^^^^^
@ -704,11 +677,6 @@ In order to customise the flag syntax we also have a few options that can be pas
a command line parser. The syntax is mainly inspired by Discord's search bar input and as a result
all flags need a corresponding value.
.. admonition:: Slash Command Only
As these are built very similar to slash command options, they are converted into options and parsed
back into flags when the slash command is executed.
The flag converter is similar to regular commands and allows you to use most types of converters
(with the exception of :class:`~ext.commands.Greedy`) as the type annotation. Some extra support is added for specific
annotations as described below.

Some files were not shown because too many files have changed in this diff Show More