Merge pull request #12

* Clean up python

* Clean up bot python

* revert lists

* revert commands.bot completely

* extract raise_expected_coro further

* add new lines

* removed erroneous import

* remove hashed line
This commit is contained in:
chillymosh
2021-09-02 20:32:46 +01:00
committed by GitHub
parent 092fbca08f
commit 42c0a8d8a5
9 changed files with 74 additions and 84 deletions

View File

@ -43,6 +43,7 @@ from .context import Context
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
from discord.utils import raise_expected_coro
if TYPE_CHECKING:
import importlib.machinery
@ -424,11 +425,9 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The pre-invoke hook must be a coroutine.')
self._before_invoke = coro
return coro
return raise_expected_coro(
coro, 'The pre-invoke hook must be a coroutine.'
)
def after_invoke(self, coro: CFT) -> CFT:
r"""A decorator that registers a coroutine as a post-invoke hook.
@ -457,11 +456,10 @@ class BotBase(GroupMixin):
TypeError
The coroutine passed is not actually a coroutine.
"""
if not asyncio.iscoroutinefunction(coro):
raise TypeError('The post-invoke hook must be a coroutine.')
return raise_expected_coro(
coro, 'The post-invoke hook must be a coroutine.'
)
self._after_invoke = coro
return coro
# listener registration

View File

@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
import inspect
@ -61,10 +62,7 @@ T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):

View File

@ -353,15 +353,15 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]:
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
else:
if guild_id is None:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild._resolve_channel(channel_id) # type: ignore
else:
return None
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = self._resolve_channel(ctx, guild_id, channel_id)
@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]):
if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None:
raise GuildNotFound(argument)
if result is None:
raise GuildNotFound(argument)
return result
@ -939,8 +939,7 @@ class clean_content(Converter[str]):
def repl(match: re.Match) -> str:
type = match[1]
id = int(match[2])
transformed = transforms[type](id)
return transformed
return transforms[type](id)
result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument)
if self.escape_markdown:

View File

@ -82,9 +82,7 @@ class StringView:
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return self._return_index(strlen, True)
return False
def read_rest(self):
@ -95,9 +93,7 @@ class StringView:
def read(self, n):
result = self.buffer[self.index:self.index + n]
self.previous = self.index
self.index += n
return result
return self._return_index(n, result)
def get(self):
try:
@ -105,9 +101,12 @@ class StringView:
except IndexError:
result = None
return self._return_index(1, result)
def _return_index(self, arg0, arg1):
self.previous = self.index
self.index += 1
return result
self.index += arg0
return arg1
def get_word(self):
pos = 0

View File

@ -46,7 +46,9 @@ import traceback
from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
from discord.utils import MISSING, raise_expected_coro
__all__ = (
'loop',
@ -488,11 +490,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._before_loop = coro
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def after_loop(self, coro: FT) -> FT:
"""A decorator that register a coroutine to be called after the loop finished running.
@ -516,11 +514,7 @@ class Loop(Generic[LF]):
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._after_loop = coro
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def error(self, coro: ET) -> ET:
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
@ -542,11 +536,7 @@ class Loop(Generic[LF]):
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
self._error = coro # type: ignore
return coro
return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.')
def _get_next_sleep_time(self) -> datetime.datetime:
if self._sleep is not MISSING:
@ -614,8 +604,7 @@ class Loop(Generic[LF]):
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
return ret
return sorted(set(ret))
def change_interval(
self,