[commands] Add support for typing.Union as a converter
This commit is contained in:
parent
4aecdea052
commit
92dde9aef9
@ -28,6 +28,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import discord
|
import discord
|
||||||
import functools
|
import functools
|
||||||
|
import typing
|
||||||
|
|
||||||
from .errors import *
|
from .errors import *
|
||||||
from .cooldowns import Cooldown, BucketType, CooldownMapping
|
from .cooldowns import Cooldown, BucketType, CooldownMapping
|
||||||
@ -212,10 +213,7 @@ class Command:
|
|||||||
self.instance = instance
|
self.instance = instance
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def do_conversion(self, ctx, converter, argument, param):
|
async def _actual_conversion(self, ctx, converter, argument, param):
|
||||||
if converter is bool:
|
|
||||||
return _convert_to_bool(argument)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
module = converter.__module__
|
module = converter.__module__
|
||||||
except:
|
except:
|
||||||
@ -255,6 +253,25 @@ class Command:
|
|||||||
|
|
||||||
raise BadArgument('Converting to "{}" failed for parameter "{}".'.format(name, param.name)) from e
|
raise BadArgument('Converting to "{}" failed for parameter "{}".'.format(name, param.name)) from e
|
||||||
|
|
||||||
|
async def do_conversion(self, ctx, converter, argument, param):
|
||||||
|
if converter is bool:
|
||||||
|
return _convert_to_bool(argument)
|
||||||
|
|
||||||
|
if type(converter) is typing._Union:
|
||||||
|
errors = []
|
||||||
|
for conv in converter.__args__:
|
||||||
|
try:
|
||||||
|
value = await self._actual_conversion(ctx, conv, argument, param)
|
||||||
|
except CommandError as e:
|
||||||
|
errors.append(e)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
# if we're here, then we failed all the converters
|
||||||
|
raise BadUnionArgument(param, converter.__args__, errors)
|
||||||
|
|
||||||
|
return await self._actual_conversion(ctx, converter, argument, param)
|
||||||
|
|
||||||
def _get_converter(self, param):
|
def _get_converter(self, param):
|
||||||
converter = param.annotation
|
converter = param.annotation
|
||||||
if converter is param.empty:
|
if converter is param.empty:
|
||||||
|
@ -30,7 +30,8 @@ __all__ = [ 'CommandError', 'MissingRequiredArgument', 'BadArgument',
|
|||||||
'NoPrivateMessage', 'CheckFailure', 'CommandNotFound',
|
'NoPrivateMessage', 'CheckFailure', 'CommandNotFound',
|
||||||
'DisabledCommand', 'CommandInvokeError', 'TooManyArguments',
|
'DisabledCommand', 'CommandInvokeError', 'TooManyArguments',
|
||||||
'UserInputError', 'CommandOnCooldown', 'NotOwner',
|
'UserInputError', 'CommandOnCooldown', 'NotOwner',
|
||||||
'MissingPermissions', 'BotMissingPermissions', 'ConversionError']
|
'MissingPermissions', 'BotMissingPermissions', 'ConversionError',
|
||||||
|
'BadUnionArgument']
|
||||||
|
|
||||||
class CommandError(DiscordException):
|
class CommandError(DiscordException):
|
||||||
"""The base exception type for all command related errors.
|
"""The base exception type for all command related errors.
|
||||||
@ -191,3 +192,35 @@ class BotMissingPermissions(CheckFailure):
|
|||||||
fmt = ' and '.join(missing)
|
fmt = ' and '.join(missing)
|
||||||
message = 'Bot requires {} permission(s) to run command.'.format(fmt)
|
message = 'Bot requires {} permission(s) to run command.'.format(fmt)
|
||||||
super().__init__(message, *args)
|
super().__init__(message, *args)
|
||||||
|
|
||||||
|
class BadUnionArgument(UserInputError):
|
||||||
|
"""Exception raised when a :class:`typing.Union` converter fails for all
|
||||||
|
its associated types.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
-----------
|
||||||
|
param: :class:`inspect.Parameter`
|
||||||
|
The parameter that failed being converted.
|
||||||
|
converters: Tuple[Type, ...]
|
||||||
|
A tuple of converters attempted in conversion, in order of failure.
|
||||||
|
errors: List[:class:`CommandError`]
|
||||||
|
A list of errors that were caught from failing the conversion.
|
||||||
|
"""
|
||||||
|
def __init__(self, param, converters, errors):
|
||||||
|
self.param = param
|
||||||
|
self.converters = converters
|
||||||
|
self.errors = errors
|
||||||
|
|
||||||
|
def _get_name(x):
|
||||||
|
try:
|
||||||
|
return x.__name__
|
||||||
|
except AttributeError:
|
||||||
|
return x.__class__.__name__
|
||||||
|
|
||||||
|
to_string = [_get_name(x) for x in converters]
|
||||||
|
if len(to_string) > 2:
|
||||||
|
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
|
||||||
|
else:
|
||||||
|
fmt = ' or '.join(to_string)
|
||||||
|
|
||||||
|
super().__init__('Could not convert "{0.name}" into {1}.'.format(param, fmt))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user