Merge pull request #12

* Clean up python

* Clean up bot python

* revert lists

* revert commands.bot completely

* extract raise_expected_coro further

* add new lines

* removed erroneous import

* remove hashed line
This commit is contained in:
chillymosh 2021-09-02 20:32:46 +01:00 committed by GitHub
parent 092fbca08f
commit 42c0a8d8a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 74 additions and 84 deletions

View File

@ -794,13 +794,13 @@ class CustomActivity(BaseActivity):
return hash((self.name, str(self.emoji))) return hash((self.name, str(self.emoji)))
def __str__(self) -> str: def __str__(self) -> str:
if self.emoji: if not self.emoji:
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
else:
return str(self.name) return str(self.name)
if self.name:
return f'{self.emoji} {self.name}'
return str(self.emoji)
def __repr__(self) -> str: def __repr__(self) -> str:
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>' return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>'

View File

@ -43,6 +43,7 @@ from .context import Context
from . import errors from . import errors
from .help import HelpCommand, DefaultHelpCommand from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING: if TYPE_CHECKING:
import importlib.machinery import importlib.machinery
@ -424,11 +425,9 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): return raise_expected_coro(
raise TypeError('The pre-invoke hook must be a coroutine.') coro, 'The pre-invoke hook must be a coroutine.'
)
self._before_invoke = coro
return coro
def after_invoke(self, coro: CFT) -> CFT: def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
@ -457,11 +456,10 @@ class BotBase(GroupMixin):
TypeError TypeError
The coroutine passed is not actually a coroutine. The coroutine passed is not actually a coroutine.
""" """
if not asyncio.iscoroutinefunction(coro): return raise_expected_coro(
raise TypeError('The post-invoke hook must be a coroutine.') coro, 'The post-invoke hook must be a coroutine.'
)
self._after_invoke = coro
return coro
# listener registration # listener registration

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import inspect import inspect
@ -61,10 +62,7 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog") CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING: P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]): class Context(discord.abc.Messageable, Generic[BotT]):

View File

@ -353,15 +353,15 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod @staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None: if guild_id is None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id) channel = self._resolve_channel(ctx, guild_id, channel_id)
@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]):
if result is None: if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument) result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None: if result is None:
raise GuildNotFound(argument) raise GuildNotFound(argument)
return result return result
@ -939,8 +939,7 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str: def repl(match: re.Match) -> str:
type = match[1] type = match[1]
id = int(match[2]) id = int(match[2])
transformed = transforms[type](id) return transforms[type](id)
return transformed
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown: if self.escape_markdown:

View File

@ -82,9 +82,7 @@ class StringView:
def skip_string(self, string): def skip_string(self, string):
strlen = len(string) strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string: if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index return self._return_index(strlen, True)
self.index += strlen
return True
return False return False
def read_rest(self): def read_rest(self):
@ -95,9 +93,7 @@ class StringView:
def read(self, n): def read(self, n):
result = self.buffer[self.index:self.index + n] result = self.buffer[self.index:self.index + n]
self.previous = self.index return self._return_index(n, result)
self.index += n
return result
def get(self): def get(self):
try: try:
@ -105,9 +101,12 @@ class StringView:
except IndexError: except IndexError:
result = None result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index self.previous = self.index
self.index += 1 self.index += arg0
return result return arg1
def get_word(self): def get_word(self):
pos = 0 pos = 0

View File

@ -46,7 +46,9 @@ import traceback
from collections.abc import Sequence from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING, raise_expected_coro
__all__ = ( __all__ = (
'loop', 'loop',
@ -488,11 +490,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
def after_loop(self, coro: FT) -> FT: def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running. """A decorator that register a coroutine to be called after the loop finished running.
@ -516,11 +514,7 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
def error(self, coro: ET) -> ET: def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. """A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -542,11 +536,7 @@ class Loop(Generic[LF]):
TypeError TypeError
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
def _get_next_sleep_time(self) -> datetime.datetime: def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING: if self._sleep is not MISSING:
@ -614,8 +604,7 @@ class Loop(Generic[LF]):
) )
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times return sorted(set(ret))
return ret
def change_interval( def change_interval(
self, self,

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
import threading import threading
@ -63,10 +64,7 @@ __all__ = (
CREATE_NO_WINDOW: int CREATE_NO_WINDOW: int
if sys.platform != 'win32': CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000
CREATE_NO_WINDOW = 0
else:
CREATE_NO_WINDOW = 0x08000000
class AudioSource: class AudioSource:
"""Represents an audio stream. """Represents an audio stream.
@ -526,7 +524,12 @@ class FFmpegOpusAudio(FFmpegAudio):
@staticmethod @staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]:
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable exe = (
executable[:2] + 'probe'
if executable in {'ffmpeg', 'avconv'}
else executable
)
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20) output = subprocess.check_output(args, timeout=20)
codec = bitrate = None codec = bitrate = None

View File

@ -185,16 +185,16 @@ class Button(Item[V]):
@emoji.setter @emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is not None: if value is None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
self._underlying.emoji = None self._underlying.emoji = None
elif isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
@classmethod @classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B: def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls( return cls(

View File

@ -499,14 +499,14 @@ else:
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if use_clock or not reset_after: if not use_clock and reset_after:
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
else:
return float(reset_after) return float(reset_after)
utc = datetime.timezone.utc
now = datetime.datetime.now(utc)
reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc)
return (reset - now).total_seconds()
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f, *args, **kwargs):
value = f(*args, **kwargs) value = f(*args, **kwargs)
@ -659,11 +659,10 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
if isinstance(invite, Invite): if isinstance(invite, Invite):
return invite.code return invite.code
else: rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)'
rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' m = re.match(rx, invite)
m = re.match(rx, invite) if m:
if m: return m.group(1)
return m.group(1)
return invite return invite
@ -687,11 +686,10 @@ def resolve_template(code: Union[Template, str]) -> str:
if isinstance(code, Template): if isinstance(code, Template):
return code.code return code.code
else: rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)'
rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' m = re.match(rx, code)
m = re.match(rx, code) if m:
if m: return m.group(1)
return m.group(1)
return code return code
@ -1017,3 +1015,9 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None)
if style is None: if style is None:
return f'<t:{int(dt.timestamp())}>' return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>' return f'<t:{int(dt.timestamp())}:{style}>'
def raise_expected_coro(coro, error: str)-> TypeError:
if not asyncio.iscoroutinefunction(coro):
raise TypeError(error)
return coro