[commands] Inject the internal variables for bot.say & co explicitly.
This is to catch cases where it wouldn't fail to find it when inspecting the stack to catch these stack variables.
This commit is contained in:
		| @@ -33,6 +33,16 @@ from .view import StringView | ||||
| from .context import Context | ||||
| from .errors import CommandNotFound | ||||
|  | ||||
| def _get_variable(name): | ||||
|     stack = inspect.stack() | ||||
|     try: | ||||
|         for frames in stack: | ||||
|             current_locals = frames[0].f_locals | ||||
|             if name in current_locals: | ||||
|                 return current_locals[name] | ||||
|     finally: | ||||
|         del stack | ||||
|  | ||||
| def when_mentioned(bot, msg): | ||||
|     """A callable that implements a command prefix equivalent | ||||
|     to being mentioned, e.g. ``@bot ``.""" | ||||
| @@ -71,13 +81,6 @@ class Bot(GroupMixin, discord.Client): | ||||
|  | ||||
|     # internal helpers | ||||
|  | ||||
|     def _get_variable(self, name): | ||||
|         stack = inspect.stack() | ||||
|         for frames in stack: | ||||
|             current_locals = frames[0].f_locals | ||||
|             if name in current_locals: | ||||
|                 return current_locals[name] | ||||
|  | ||||
|     def _get_prefix(self, message): | ||||
|         prefix = self.command_prefix | ||||
|         if callable(prefix): | ||||
| @@ -122,7 +125,7 @@ class Bot(GroupMixin, discord.Client): | ||||
|         content : str | ||||
|             The content to pass to :class:`Client.send_message` | ||||
|         """ | ||||
|         destination = self._get_variable('_internal_channel') | ||||
|         destination = _get_variable('_internal_channel') | ||||
|         result = yield from self.send_message(destination, content) | ||||
|         return result | ||||
|  | ||||
| @@ -141,7 +144,7 @@ class Bot(GroupMixin, discord.Client): | ||||
|         content : str | ||||
|             The content to pass to :class:`Client.send_message` | ||||
|         """ | ||||
|         destination = self._get_variable('_internal_author') | ||||
|         destination = _get_variable('_internal_author') | ||||
|         result = yield from self.send_message(destination, content) | ||||
|         return result | ||||
|  | ||||
| @@ -161,8 +164,8 @@ class Bot(GroupMixin, discord.Client): | ||||
|         content : str | ||||
|             The content to pass to :class:`Client.send_message` | ||||
|         """ | ||||
|         author = self._get_variable('_internal_author') | ||||
|         destination = self._get_variable('_internal_channel') | ||||
|         author = _get_variable('_internal_author') | ||||
|         destination = _get_variable('_internal_channel') | ||||
|         fmt = '{0.mention}, {1}'.format(author, str(content)) | ||||
|         result = yield from self.send_message(destination, fmt) | ||||
|         return result | ||||
| @@ -184,7 +187,7 @@ class Bot(GroupMixin, discord.Client): | ||||
|         name | ||||
|             The second parameter to pass to :meth:`Client.send_file` | ||||
|         """ | ||||
|         destination = self._get_variable('_internal_channel') | ||||
|         destination = _get_variable('_internal_channel') | ||||
|         result = yield from self.send_file(destination, fp, name) | ||||
|         return result | ||||
|  | ||||
| @@ -202,7 +205,7 @@ class Bot(GroupMixin, discord.Client): | ||||
|         --------- | ||||
|         The :meth:`Client.send_typing` function. | ||||
|         """ | ||||
|         destination = self._get_variable('_internal_channel') | ||||
|         destination = _get_variable('_internal_channel') | ||||
|         yield from self.send_typing(destination) | ||||
|  | ||||
|     # listener registration | ||||
|   | ||||
| @@ -28,7 +28,7 @@ import asyncio | ||||
| import inspect | ||||
| import re | ||||
| import discord | ||||
| from functools import partial | ||||
| import functools | ||||
|  | ||||
| from .errors import * | ||||
| from .view import quoted_word | ||||
| @@ -36,6 +36,17 @@ from .view import quoted_word | ||||
| __all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group', | ||||
|             'has_role', 'has_permissions', 'has_any_role', 'check' ] | ||||
|  | ||||
| def inject_context(ctx, coro): | ||||
|     @functools.wraps(coro) | ||||
|     @asyncio.coroutine | ||||
|     def wrapped(*args, **kwargs): | ||||
|         _internal_channel = ctx.message.channel | ||||
|         _internal_author = ctx.message.author | ||||
|  | ||||
|         ret = yield from coro(*args, **kwargs) | ||||
|         return ret | ||||
|     return wrapped | ||||
|  | ||||
| def _convert_to_bool(argument): | ||||
|     lowered = argument.lower() | ||||
|     if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): | ||||
| @@ -103,10 +114,11 @@ class Command: | ||||
|         except AttributeError: | ||||
|             return | ||||
|  | ||||
|         injected = inject_context(ctx, coro) | ||||
|         if self.instance is not None: | ||||
|             discord.utils.create_task(coro(self.instance, error, ctx), loop=ctx.bot.loop) | ||||
|             discord.utils.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop) | ||||
|         else: | ||||
|             discord.utils.create_task(coro(error, ctx), loop=ctx.bot.loop) | ||||
|             discord.utils.create_task(injected(error, ctx), loop=ctx.bot.loop) | ||||
|  | ||||
|     def _receive_item(self, message, argument, regex, receiver, generator): | ||||
|         match = re.match(regex, argument) | ||||
| @@ -263,7 +275,8 @@ class Command: | ||||
|             return | ||||
|  | ||||
|         if self._parse_arguments(ctx): | ||||
|             yield from self.callback(*ctx.args, **ctx.kwargs) | ||||
|             injected = inject_context(ctx, self.callback) | ||||
|             yield from injected(*ctx.args, **ctx.kwargs) | ||||
|  | ||||
|     def error(self, coro): | ||||
|         """A decorator that registers a coroutine as a local error handler. | ||||
| @@ -425,7 +438,8 @@ class Group(GroupMixin, Command): | ||||
|             if trigger in self.commands: | ||||
|                 ctx.invoked_subcommand = self.commands[trigger] | ||||
|  | ||||
|         yield from self.callback(*ctx.args, **ctx.kwargs) | ||||
|         injected = inject_context(ctx, self.callback) | ||||
|         yield from injected(*ctx.args, **ctx.kwargs) | ||||
|  | ||||
|         if ctx.invoked_subcommand: | ||||
|             ctx.invoked_with = trigger | ||||
| @@ -616,7 +630,7 @@ def has_any_role(*names): | ||||
|         if ch.is_private: | ||||
|             return False | ||||
|  | ||||
|         getter = partial(discord.utils.get, msg.author.roles) | ||||
|         getter = functools.partial(discord.utils.get, msg.author.roles) | ||||
|         return any(getter(name=name) is not None for name in names) | ||||
|     return check(predicate) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user