diff --git a/discord/activity.py b/discord/activity.py index cba61f38..51205377 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -794,13 +794,13 @@ class CustomActivity(BaseActivity): return hash((self.name, str(self.emoji))) def __str__(self) -> str: - if not self.emoji: + if self.emoji: + if self.name: + return f'{self.emoji} {self.name}' + return str(self.emoji) + else: return str(self.name) - if self.name: - return f'{self.emoji} {self.name}' - return str(self.emoji) - def __repr__(self) -> str: return f'' diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index b3a1fb57..e03562b6 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -43,7 +43,6 @@ from .context import Context from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog -from discord.utils import raise_expected_coro if TYPE_CHECKING: import importlib.machinery @@ -425,9 +424,11 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - return raise_expected_coro( - coro, 'The pre-invoke hook must be a coroutine.' - ) + if not asyncio.iscoroutinefunction(coro): + raise TypeError('The pre-invoke hook must be a coroutine.') + + self._before_invoke = coro + return coro def after_invoke(self, coro: CFT) -> CFT: r"""A decorator that registers a coroutine as a post-invoke hook. @@ -456,10 +457,11 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - return raise_expected_coro( - coro, 'The post-invoke hook must be a coroutine.' - ) + if not asyncio.iscoroutinefunction(coro): + raise TypeError('The post-invoke hook must be a coroutine.') + self._after_invoke = coro + return coro # listener registration diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 158c84ea..fa16c74a 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from __future__ import annotations import inspect @@ -62,7 +61,10 @@ T = TypeVar('T') BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") CogT = TypeVar('CogT', bound="Cog") -P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P') +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') class Context(discord.abc.Messageable, Generic[BotT]): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index ce7037a4..5740a188 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -353,14 +353,14 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: - if guild_id is None: - return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel - - guild = ctx.bot.get_guild(guild_id) - if guild is not None and channel_id is not None: - return guild._resolve_channel(channel_id) # type: ignore + if guild_id is not None: + guild = ctx.bot.get_guild(guild_id) + if guild is not None and channel_id is not None: + return guild._resolve_channel(channel_id) # type: ignore + else: + return None else: - return None + return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) @@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]): if result is None: result = discord.utils.get(ctx.bot.guilds, name=argument) - if result is None: - raise GuildNotFound(argument) + if result is None: + raise GuildNotFound(argument) return result @@ -939,7 +939,8 @@ class clean_content(Converter[str]): def repl(match: re.Match) -> str: type = match[1] id = int(match[2]) - return transforms[type](id) + transformed = transforms[type](id) + return transformed result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) if self.escape_markdown: diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index 39cc35f7..a7dc7236 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -82,7 +82,9 @@ class StringView: def skip_string(self, string): strlen = len(string) if self.buffer[self.index:self.index + strlen] == string: - return self._return_index(strlen, True) + self.previous = self.index + self.index += strlen + return True return False def read_rest(self): @@ -93,7 +95,9 @@ class StringView: def read(self, n): result = self.buffer[self.index:self.index + n] - return self._return_index(n, result) + self.previous = self.index + self.index += n + return result def get(self): try: @@ -101,12 +105,9 @@ class StringView: except IndexError: result = None - return self._return_index(1, result) - - def _return_index(self, arg0, arg1): self.previous = self.index - self.index += arg0 - return arg1 + self.index += 1 + return result def get_word(self): pos = 0 diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 9518390e..5b78f10e 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -46,9 +46,7 @@ import traceback from collections.abc import Sequence from discord.backoff import ExponentialBackoff -from discord.utils import MISSING, raise_expected_coro - - +from discord.utils import MISSING __all__ = ( 'loop', @@ -490,7 +488,11 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._before_loop = coro + return coro def after_loop(self, coro: FT) -> FT: """A decorator that register a coroutine to be called after the loop finished running. @@ -514,7 +516,11 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._after_loop = coro + return coro def error(self, coro: ET) -> ET: """A decorator that registers a coroutine to be called if the task encounters an unhandled exception. @@ -536,7 +542,11 @@ class Loop(Generic[LF]): TypeError The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._error = coro # type: ignore + return coro def _get_next_sleep_time(self) -> datetime.datetime: if self._sleep is not MISSING: @@ -604,7 +614,8 @@ class Loop(Generic[LF]): ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) - return sorted(set(ret)) + ret = sorted(set(ret)) # de-dupe and sort times + return ret def change_interval( self, diff --git a/discord/player.py b/discord/player.py index 79579c8d..8098d3e3 100644 --- a/discord/player.py +++ b/discord/player.py @@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from __future__ import annotations import threading @@ -64,7 +63,10 @@ __all__ = ( CREATE_NO_WINDOW: int -CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000 +if sys.platform != 'win32': + CREATE_NO_WINDOW = 0 +else: + CREATE_NO_WINDOW = 0x08000000 class AudioSource: """Represents an audio stream. @@ -524,12 +526,7 @@ class FFmpegOpusAudio(FFmpegAudio): @staticmethod def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - exe = ( - executable[:2] + 'probe' - if executable in {'ffmpeg', 'avconv'} - else executable - ) - + exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] output = subprocess.check_output(args, timeout=20) codec = bitrate = None diff --git a/discord/ui/button.py b/discord/ui/button.py index 0b16e87e..fedeac68 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -185,15 +185,15 @@ class Button(Item[V]): @emoji.setter def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore - if value is None: - self._underlying.emoji = None - - elif isinstance(value, str): - self._underlying.emoji = PartialEmoji.from_str(value) - elif isinstance(value, _EmojiTag): - self._underlying.emoji = value._to_partial() + if value is not None: + if isinstance(value, str): + self._underlying.emoji = PartialEmoji.from_str(value) + elif isinstance(value, _EmojiTag): + self._underlying.emoji = value._to_partial() + else: + raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') else: - raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') + self._underlying.emoji = None @classmethod def from_component(cls: Type[B], button: ButtonComponent) -> B: diff --git a/discord/utils.py b/discord/utils.py index cad99da6..4360b77a 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -499,14 +499,14 @@ else: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') - if not use_clock and reset_after: + if use_clock or not reset_after: + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) + return (reset - now).total_seconds() + else: return float(reset_after) - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) - return (reset - now).total_seconds() - async def maybe_coroutine(f, *args, **kwargs): value = f(*args, **kwargs) @@ -659,10 +659,11 @@ def resolve_invite(invite: Union[Invite, str]) -> str: if isinstance(invite, Invite): return invite.code - rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' - m = re.match(rx, invite) - if m: - return m.group(1) + else: + rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' + m = re.match(rx, invite) + if m: + return m.group(1) return invite @@ -686,10 +687,11 @@ def resolve_template(code: Union[Template, str]) -> str: if isinstance(code, Template): return code.code - rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' - m = re.match(rx, code) - if m: - return m.group(1) + else: + rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' + m = re.match(rx, code) + if m: + return m.group(1) return code @@ -1015,9 +1017,3 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) if style is None: return f'' return f'' - - -def raise_expected_coro(coro, error: str)-> TypeError: - if not asyncio.iscoroutinefunction(coro): - raise TypeError(error) - return coro