[commands] Get guild_id from message link for message converters

This commit is contained in:
Rapptz
2021-07-09 08:19:56 -04:00
parent 0aa825557d
commit 1a4e73d599

View File

@@ -326,22 +326,42 @@ class PartialMessageConverter(Converter[discord.PartialMessage]):
""" """
@staticmethod @staticmethod
def _get_id_matches(argument): def _get_id_matches(ctx, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$') id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
link_regex = re.compile( link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/' r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?:[0-9]{15,20}|@me)' r'(?P<guild_id>[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$' r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
) )
match = id_regex.match(argument) or link_regex.match(argument) match = id_regex.match(argument) or link_regex.match(argument)
if not match: if not match:
raise MessageNotFound(argument) raise MessageNotFound(argument)
channel_id = match.group('channel_id') data = match.groupdict()
return int(match.group('message_id')), int(channel_id) if channel_id else None channel_id = discord.utils._get_as_snowflake(data, 'channel_id')
message_id = int(data['message_id'])
guild_id = data.get('guild_id')
if guild_id is None:
guild_id = ctx.guild and ctx.guild.id
elif guild_id == '@me':
guild_id = None
else:
guild_id = int(guild_id)
return guild_id, message_id, channel_id
@staticmethod
def _resolve_channel(ctx, guild_id, channel_id):
if guild_id is not None:
guild = ctx.bot.get_guild(guild_id)
if guild is not None and channel_id is not None:
return guild.get_channel(channel_id)
else:
return None
else:
return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage:
message_id, channel_id = self._get_id_matches(argument) guild_id, message_id, channel_id = self._get_id_matches(ctx, argument)
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel channel = self._resolve_channel(ctx, guild_id, channel_id)
if not channel: if not channel:
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id) return discord.PartialMessage(channel=channel, id=message_id)
@@ -363,11 +383,11 @@ class MessageConverter(IDConverter[discord.Message]):
""" """
async def convert(self, ctx: Context, argument: str) -> discord.Message: async def convert(self, ctx: Context, argument: str) -> discord.Message:
message_id, channel_id = PartialMessageConverter._get_id_matches(argument) guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument)
message = ctx.bot._connection._get_message(message_id) message = ctx.bot._connection._get_message(message_id)
if message: if message:
return message return message
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel channel = PartialMessageConverter._resolve_channel(ctx, guild_id, channel_id)
if not channel: if not channel:
raise ChannelNotFound(channel_id) raise ChannelNotFound(channel_id)
try: try: