mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-03 23:12:56 +00:00 
			
		
		
		
	[commands] Inject the internal variables for bot.say & co explicitly.
This is to catch cases where it wouldn't fail to find it when inspecting the stack to catch these stack variables.
This commit is contained in:
		@@ -33,6 +33,16 @@ from .view import StringView
 | 
			
		||||
from .context import Context
 | 
			
		||||
from .errors import CommandNotFound
 | 
			
		||||
 | 
			
		||||
def _get_variable(name):
 | 
			
		||||
    stack = inspect.stack()
 | 
			
		||||
    try:
 | 
			
		||||
        for frames in stack:
 | 
			
		||||
            current_locals = frames[0].f_locals
 | 
			
		||||
            if name in current_locals:
 | 
			
		||||
                return current_locals[name]
 | 
			
		||||
    finally:
 | 
			
		||||
        del stack
 | 
			
		||||
 | 
			
		||||
def when_mentioned(bot, msg):
 | 
			
		||||
    """A callable that implements a command prefix equivalent
 | 
			
		||||
    to being mentioned, e.g. ``@bot ``."""
 | 
			
		||||
@@ -71,13 +81,6 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
 | 
			
		||||
    # internal helpers
 | 
			
		||||
 | 
			
		||||
    def _get_variable(self, name):
 | 
			
		||||
        stack = inspect.stack()
 | 
			
		||||
        for frames in stack:
 | 
			
		||||
            current_locals = frames[0].f_locals
 | 
			
		||||
            if name in current_locals:
 | 
			
		||||
                return current_locals[name]
 | 
			
		||||
 | 
			
		||||
    def _get_prefix(self, message):
 | 
			
		||||
        prefix = self.command_prefix
 | 
			
		||||
        if callable(prefix):
 | 
			
		||||
@@ -122,7 +125,7 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
        content : str
 | 
			
		||||
            The content to pass to :class:`Client.send_message`
 | 
			
		||||
        """
 | 
			
		||||
        destination = self._get_variable('_internal_channel')
 | 
			
		||||
        destination = _get_variable('_internal_channel')
 | 
			
		||||
        result = yield from self.send_message(destination, content)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
@@ -141,7 +144,7 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
        content : str
 | 
			
		||||
            The content to pass to :class:`Client.send_message`
 | 
			
		||||
        """
 | 
			
		||||
        destination = self._get_variable('_internal_author')
 | 
			
		||||
        destination = _get_variable('_internal_author')
 | 
			
		||||
        result = yield from self.send_message(destination, content)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
@@ -161,8 +164,8 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
        content : str
 | 
			
		||||
            The content to pass to :class:`Client.send_message`
 | 
			
		||||
        """
 | 
			
		||||
        author = self._get_variable('_internal_author')
 | 
			
		||||
        destination = self._get_variable('_internal_channel')
 | 
			
		||||
        author = _get_variable('_internal_author')
 | 
			
		||||
        destination = _get_variable('_internal_channel')
 | 
			
		||||
        fmt = '{0.mention}, {1}'.format(author, str(content))
 | 
			
		||||
        result = yield from self.send_message(destination, fmt)
 | 
			
		||||
        return result
 | 
			
		||||
@@ -184,7 +187,7 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
        name
 | 
			
		||||
            The second parameter to pass to :meth:`Client.send_file`
 | 
			
		||||
        """
 | 
			
		||||
        destination = self._get_variable('_internal_channel')
 | 
			
		||||
        destination = _get_variable('_internal_channel')
 | 
			
		||||
        result = yield from self.send_file(destination, fp, name)
 | 
			
		||||
        return result
 | 
			
		||||
 | 
			
		||||
@@ -202,7 +205,7 @@ class Bot(GroupMixin, discord.Client):
 | 
			
		||||
        ---------
 | 
			
		||||
        The :meth:`Client.send_typing` function.
 | 
			
		||||
        """
 | 
			
		||||
        destination = self._get_variable('_internal_channel')
 | 
			
		||||
        destination = _get_variable('_internal_channel')
 | 
			
		||||
        yield from self.send_typing(destination)
 | 
			
		||||
 | 
			
		||||
    # listener registration
 | 
			
		||||
 
 | 
			
		||||
@@ -28,7 +28,7 @@ import asyncio
 | 
			
		||||
import inspect
 | 
			
		||||
import re
 | 
			
		||||
import discord
 | 
			
		||||
from functools import partial
 | 
			
		||||
import functools
 | 
			
		||||
 | 
			
		||||
from .errors import *
 | 
			
		||||
from .view import quoted_word
 | 
			
		||||
@@ -36,6 +36,17 @@ from .view import quoted_word
 | 
			
		||||
__all__ = [ 'Command', 'Group', 'GroupMixin', 'command', 'group',
 | 
			
		||||
            'has_role', 'has_permissions', 'has_any_role', 'check' ]
 | 
			
		||||
 | 
			
		||||
def inject_context(ctx, coro):
 | 
			
		||||
    @functools.wraps(coro)
 | 
			
		||||
    @asyncio.coroutine
 | 
			
		||||
    def wrapped(*args, **kwargs):
 | 
			
		||||
        _internal_channel = ctx.message.channel
 | 
			
		||||
        _internal_author = ctx.message.author
 | 
			
		||||
 | 
			
		||||
        ret = yield from coro(*args, **kwargs)
 | 
			
		||||
        return ret
 | 
			
		||||
    return wrapped
 | 
			
		||||
 | 
			
		||||
def _convert_to_bool(argument):
 | 
			
		||||
    lowered = argument.lower()
 | 
			
		||||
    if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'):
 | 
			
		||||
@@ -103,10 +114,11 @@ class Command:
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        injected = inject_context(ctx, coro)
 | 
			
		||||
        if self.instance is not None:
 | 
			
		||||
            discord.utils.create_task(coro(self.instance, error, ctx), loop=ctx.bot.loop)
 | 
			
		||||
            discord.utils.create_task(injected(self.instance, error, ctx), loop=ctx.bot.loop)
 | 
			
		||||
        else:
 | 
			
		||||
            discord.utils.create_task(coro(error, ctx), loop=ctx.bot.loop)
 | 
			
		||||
            discord.utils.create_task(injected(error, ctx), loop=ctx.bot.loop)
 | 
			
		||||
 | 
			
		||||
    def _receive_item(self, message, argument, regex, receiver, generator):
 | 
			
		||||
        match = re.match(regex, argument)
 | 
			
		||||
@@ -263,7 +275,8 @@ class Command:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if self._parse_arguments(ctx):
 | 
			
		||||
            yield from self.callback(*ctx.args, **ctx.kwargs)
 | 
			
		||||
            injected = inject_context(ctx, self.callback)
 | 
			
		||||
            yield from injected(*ctx.args, **ctx.kwargs)
 | 
			
		||||
 | 
			
		||||
    def error(self, coro):
 | 
			
		||||
        """A decorator that registers a coroutine as a local error handler.
 | 
			
		||||
@@ -425,7 +438,8 @@ class Group(GroupMixin, Command):
 | 
			
		||||
            if trigger in self.commands:
 | 
			
		||||
                ctx.invoked_subcommand = self.commands[trigger]
 | 
			
		||||
 | 
			
		||||
        yield from self.callback(*ctx.args, **ctx.kwargs)
 | 
			
		||||
        injected = inject_context(ctx, self.callback)
 | 
			
		||||
        yield from injected(*ctx.args, **ctx.kwargs)
 | 
			
		||||
 | 
			
		||||
        if ctx.invoked_subcommand:
 | 
			
		||||
            ctx.invoked_with = trigger
 | 
			
		||||
@@ -616,7 +630,7 @@ def has_any_role(*names):
 | 
			
		||||
        if ch.is_private:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        getter = partial(discord.utils.get, msg.author.roles)
 | 
			
		||||
        getter = functools.partial(discord.utils.get, msg.author.roles)
 | 
			
		||||
        return any(getter(name=name) is not None for name in names)
 | 
			
		||||
    return check(predicate)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user