[commands]Add typing.Literal converter
This commit is contained in:
@ -489,31 +489,52 @@ class Command(_BaseCommand):
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
|
||||
async def do_conversion(self, ctx, converter, argument, param):
|
||||
try:
|
||||
origin = converter.__origin__
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if origin is typing.Union:
|
||||
errors = []
|
||||
_NoneType = type(None)
|
||||
for conv in converter.__args__:
|
||||
# if we got to this part in the code, then the previous conversions have failed
|
||||
# so we should just undo the view, return the default, and allow parsing to continue
|
||||
# with the other parameters
|
||||
if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
|
||||
ctx.view.undo()
|
||||
return None if param.default is param.empty else param.default
|
||||
origin = typing.get_origin(converter)
|
||||
|
||||
if origin is typing.Union:
|
||||
errors = []
|
||||
_NoneType = type(None)
|
||||
for conv in typing.get_args(converter):
|
||||
# if we got to this part in the code, then the previous conversions have failed
|
||||
# so we should just undo the view, return the default, and allow parsing to continue
|
||||
# with the other parameters
|
||||
if conv is _NoneType and param.kind != param.VAR_POSITIONAL:
|
||||
ctx.view.undo()
|
||||
return None if param.default is param.empty else param.default
|
||||
|
||||
try:
|
||||
value = await self.do_conversion(ctx, conv, argument, param)
|
||||
except CommandError as exc:
|
||||
errors.append(exc)
|
||||
else:
|
||||
return value
|
||||
|
||||
# if we're here, then we failed all the converters
|
||||
raise BadUnionArgument(param, typing.get_args(converter), errors)
|
||||
|
||||
if origin is typing.Literal:
|
||||
errors = []
|
||||
conversions = {}
|
||||
literal_args = tuple(self._flattened_typing_literal_args(converter))
|
||||
for literal in literal_args:
|
||||
literal_type = type(literal)
|
||||
try:
|
||||
value = conversions[literal_type]
|
||||
except KeyError:
|
||||
try:
|
||||
value = await self._actual_conversion(ctx, conv, argument, param)
|
||||
value = await self._actual_conversion(ctx, literal_type, argument, param)
|
||||
except CommandError as exc:
|
||||
errors.append(exc)
|
||||
conversions[literal_type] = object()
|
||||
continue
|
||||
else:
|
||||
return value
|
||||
conversions[literal_type] = value
|
||||
|
||||
# if we're here, then we failed all the converters
|
||||
raise BadUnionArgument(param, converter.__args__, errors)
|
||||
if value == literal:
|
||||
return value
|
||||
|
||||
# if we're here, then we failed to match all the literals
|
||||
raise BadLiteralArgument(param, literal_args, errors)
|
||||
|
||||
return await self._actual_conversion(ctx, converter, argument, param)
|
||||
|
||||
@ -995,15 +1016,14 @@ class Command(_BaseCommand):
|
||||
return ''
|
||||
|
||||
def _is_typing_optional(self, annotation):
|
||||
try:
|
||||
origin = annotation.__origin__
|
||||
except AttributeError:
|
||||
return False
|
||||
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
|
||||
|
||||
if origin is not typing.Union:
|
||||
return False
|
||||
|
||||
return annotation.__args__[-1] is type(None)
|
||||
def _flattened_typing_literal_args(self, annotation):
|
||||
for literal in typing.get_args(annotation):
|
||||
if typing.get_origin(literal) is typing.Literal:
|
||||
yield from self._flattened_typing_literal_args(literal)
|
||||
else:
|
||||
yield literal
|
||||
|
||||
@property
|
||||
def signature(self):
|
||||
@ -1011,7 +1031,6 @@ class Command(_BaseCommand):
|
||||
if self.usage is not None:
|
||||
return self.usage
|
||||
|
||||
|
||||
params = self.clean_params
|
||||
if not params:
|
||||
return ''
|
||||
@ -1019,6 +1038,22 @@ class Command(_BaseCommand):
|
||||
result = []
|
||||
for name, param in params.items():
|
||||
greedy = isinstance(param.annotation, converters._Greedy)
|
||||
optional = False # postpone evaluation of if it's an optional argument
|
||||
|
||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||
# parameter signature is a literal list of it's values
|
||||
annotation = param.annotation.converter if greedy else param.annotation
|
||||
origin = typing.get_origin(annotation)
|
||||
if not greedy and origin is typing.Union:
|
||||
union_args = typing.get_args(annotation)
|
||||
optional = union_args[-1] is type(None)
|
||||
if optional:
|
||||
annotation = union_args[0]
|
||||
origin = typing.get_origin(annotation)
|
||||
|
||||
if origin is typing.Literal:
|
||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
|
||||
for v in self._flattened_typing_literal_args(annotation))
|
||||
|
||||
if param.default is not param.empty:
|
||||
# We don't want None or '' to trigger the [name=value] case and instead it should
|
||||
@ -1038,7 +1073,7 @@ class Command(_BaseCommand):
|
||||
result.append(f'[{name}...]')
|
||||
elif greedy:
|
||||
result.append(f'[{name}]...')
|
||||
elif self._is_typing_optional(param.annotation):
|
||||
elif optional:
|
||||
result.append(f'[{name}]')
|
||||
else:
|
||||
result.append(f'<{name}>')
|
||||
|
Reference in New Issue
Block a user