This commit is contained in:
IAmTomahawkx 2021-09-05 14:33:00 -07:00
commit a23dae8604
10 changed files with 85 additions and 75 deletions

View File

@ -794,13 +794,13 @@ class CustomActivity(BaseActivity):
return hash((self.name, str(self.emoji))) return hash((self.name, str(self.emoji)))
def __str__(self) -> str: def __str__(self) -> str:
if not self.emoji: if self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name) return str(self.name)
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>' return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'

View File

@ -43,7 +43,6 @@ from .context import Context
from . import errors from . import errors
from .help import HelpCommand, DefaultHelpCommand from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING: if TYPE_CHECKING:
import importlib.machinery import importlib.machinery
@ -425,9 +424,11 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
return raise_expected_coro( if not asyncio.iscoroutinefunction(coro):
coro, 'The pre-invoke hook must be a coroutine.' raise TypeError('The pre-invoke hook must be a coroutine.')
)
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT: def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
@ -456,10 +457,11 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
return raise_expected_coro( if not asyncio.iscoroutinefunction(coro):
coro, 'The post-invoke hook must be a coroutine.' raise TypeError('The post-invoke hook must be a coroutine.')
)
self._after_invoke = coro
return coro
# listener registration # listener registration

View File

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import inspect import inspect
@ -62,7 +61,10 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar('CogT', bound="Cog")
P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P') if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]): class Context(discord.abc.Messageable, Generic[BotT]):

View File

@ -353,14 +353,14 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is None: if guild_id is not None:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
guild = ctx.bot.get_guild(guild_id) return guild._resolve_channel(channel_id) # type: ignore
if guild is not None and channel_id is not None: else:
return guild._resolve_channel(channel_id) # type: ignore return None
else: else:
return None return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]):
if result is None: if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument) result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None: if result is None:
raise GuildNotFound(argument) raise GuildNotFound(argument)
return result return result
@ -939,7 +939,8 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str: def repl(match: re.Match) -> str:
type = match[1] type = match[1]
id = int(match[2]) id = int(match[2])
return transforms[type](id) transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown: if self.escape_markdown:

View File

@ -82,7 +82,9 @@ class StringView:
def skip_string(self, string): def skip_string(self, string):
strlen = len(string) strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string: if self.buffer[self.index:self.index + strlen] == string:
return self._return_index(strlen, True) self.previous = self.index
self.index += strlen
return True
return False return False
def read_rest(self): def read_rest(self):
@ -93,7 +95,9 @@ class StringView:
def read(self, n): def read(self, n):
result = self.buffer[self.index:self.index + n] result = self.buffer[self.index:self.index + n]
return self._return_index(n, result) self.previous = self.index
self.index += n
return result
def get(self): def get(self):
try: try:
@ -101,12 +105,9 @@ class StringView:
except IndexError: except IndexError:
result = None result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index self.previous = self.index
self.index += arg0 self.index += 1
return arg1 return result
def get_word(self): def get_word(self):
pos = 0 pos = 0

View File

@ -46,9 +46,7 @@ import traceback
from collections.abc import Sequence from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING, raise_expected_coro from discord.utils import MISSING
__all__ = ( __all__ = (
'loop', 'loop',
@ -490,7 +488,11 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
def after_loop(self, coro: FT) -> FT: def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running. """A decorator that register a coroutine to be called after the loop finished running.
@ -514,7 +516,11 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
def error(self, coro: ET) -> ET: def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. """A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -536,7 +542,11 @@ class Loop(Generic[LF]):
TypeError TypeError
The function was not a coroutine. The function was not a coroutine.
""" """
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime: def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING: if self._sleep is not MISSING:
@ -604,7 +614,8 @@ class Loop(Generic[LF]):
) )
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
return sorted(set(ret)) ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval( def change_interval(
self, self,

View File

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import threading import threading
@ -64,7 +63,10 @@ __all__ = (
CREATE_NO_WINDOW: int CREATE_NO_WINDOW: int
CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000 if sys.platform != 'win32':
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource: class AudioSource:
"""Represents an audio stream. """Represents an audio stream.
@ -524,12 +526,7 @@ class FFmpegOpusAudio(FFmpegAudio):
@staticmethod @staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = ( exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
executable[:2] + 'probe'
if executable in {'ffmpeg', 'avconv'}
else executable
)
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20) output = subprocess.check_output(args, timeout=20)
codec = bitrate = None codec = bitrate = None

View File

@ -843,4 +843,4 @@ class ThreadMember(Hashable):
The member or ``None`` if not found. The member or ``None`` if not found.
""" """
return await self.thread.guild.get_member(self.id) return self.thread.guild.get_member(self.id)

View File

@ -185,15 +185,15 @@ class Button(Item[V]):
@emoji.setter @emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is None: if value is not None:
self._underlying.emoji = None if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, str): elif isinstance(value, _EmojiTag):
self._underlying.emoji = PartialEmoji.from_str(value) self._underlying.emoji = value._to_partial()
elif isinstance(value, _EmojiTag): else:
self._underlying.emoji = value._to_partial() raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else: else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') self._underlying.emoji = None
@classmethod @classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B: def from_component(cls: Type[B], button: ButtonComponent) -> B:

View File

@ -499,14 +499,14 @@ else:
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if not use_clock and reset_after: if use_clock or not reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after) return float(reset_after)
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f, *args, **kwargs):
value = f(*args, **kwargs) value = f(*args, **kwargs)
@ -659,10 +659,11 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite): if isinstance(invite, Invite):
return invite.code return invite.code
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' else:
m = re.match(rx, invite) rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
if m: m = re.match(rx, invite)
return m.group(1) if m:
return m.group(1)
return invite return invite
@ -686,10 +687,11 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template): if isinstance(code, Template):
return code.code return code.code
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' else:
m = re.match(rx, code) rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
if m: m = re.match(rx, code)
return m.group(1) if m:
return m.group(1)
return code return code
@ -1015,9 +1017,3 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
if style is None: if style is None:
return f'<t:{int(dt.timestamp())}>' return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>' return f'<t:{int(dt.timestamp())}:{style}>'
def raise_expected_coro(coro, error: str)-> TypeError:
if not asyncio.iscoroutinefunction(coro):
raise TypeError(error)
return coro