This commit is contained in:
IAmTomahawkx 2021-09-05 14:33:00 -07:00
commit a23dae8604
10 changed files with 85 additions and 75 deletions

View File

@ -794,13 +794,13 @@ class CustomActivity(BaseActivity):
return hash((self.name, str(self.emoji)))
def __str__(self) -> str:
if not self.emoji:
if self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name)
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'

View File

@ -43,7 +43,6 @@ from .context import Context
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING:
import importlib.machinery
@ -425,9 +424,11 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
return raise_expected_coro(
coro, 'The pre-invoke hook must be a coroutine.'
)
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
@ -456,10 +457,11 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
return raise_expected_coro(
coro, 'The post-invoke hook must be a coroutine.'
)
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
self._after_invoke = coro
return coro
# listener registration

View File

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
@ -62,7 +61,10 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):

View File

@ -353,14 +353,14 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is None:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
return None
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]):
if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None:
raise GuildNotFound(argument)
if result is None:
raise GuildNotFound(argument)
return result
@ -939,7 +939,8 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str:
type = match[1]
id = int(match[2])
return transforms[type](id)
transformed = transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown:

View File

@ -82,7 +82,9 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
return self._return_index(strlen, True)
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
@ -93,7 +95,9 @@ class StringView:
def read(self, n):
result = self.buffer[self.index:self.index + n]
return self._return_index(n, result)
self.previous = self.index
self.index += n
return result
def get(self):
try:
@ -101,12 +105,9 @@ class StringView:
except IndexError:
result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index
self.index += arg0
return arg1
self.index += 1
return result
def get_word(self):
pos = 0

View File

@ -46,9 +46,7 @@ import traceback
from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING, raise_expected_coro
from discord.utils import MISSING
__all__ = (
'loop',
@ -490,7 +488,11 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running.
@ -514,7 +516,11 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -536,7 +542,11 @@ class Loop(Generic[LF]):
TypeError
The function was not a coroutine.
"""
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING:
@ -604,7 +614,8 @@ class Loop(Generic[LF]):
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
return sorted(set(ret))
ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval(
self,

View File

@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import threading
@ -64,7 +63,10 @@ __all__ = (
CREATE_NO_WINDOW: int
CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000
if sys.platform != 'win32':
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource:
"""Represents an audio stream.
@ -524,12 +526,7 @@ class FFmpegOpusAudio(FFmpegAudio):
@staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = (
executable[:2] + 'probe'
if executable in {'ffmpeg', 'avconv'}
else executable
)
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20)
codec = bitrate = None

View File

@ -843,4 +843,4 @@ class ThreadMember(Hashable):
The member or ``None`` if not found.
"""
return await self.thread.guild.get_member(self.id)
return self.thread.guild.get_member(self.id)

View File

@ -185,15 +185,15 @@ class Button(Item[V]):
@emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is None:
self._underlying.emoji = None
elif isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
if value is not None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
self._underlying.emoji = None
@classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B:

View File

@ -499,14 +499,14 @@ else:
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if not use_clock and reset_after:
if use_clock or not reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after)
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
async def maybe_coroutine(f, *args, **kwargs):
value = f(*args, **kwargs)
@ -659,10 +659,11 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite):
return invite.code
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
m = re.match(rx, invite)
if m:
return m.group(1)
else:
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
m = re.match(rx, invite)
if m:
return m.group(1)
return invite
@ -686,10 +687,11 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template):
return code.code
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
m = re.match(rx, code)
if m:
return m.group(1)
else:
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
m = re.match(rx, code)
if m:
return m.group(1)
return code
@ -1015,9 +1017,3 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
if style is None:
return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>'
def raise_expected_coro(coro, error: str)-> TypeError:
if not asyncio.iscoroutinefunction(coro):
raise TypeError(error)
return coro