Mirror of https://github.com/Rapptz/discord.py.git
Take back ownership of files from aiohttp for retrying requests.
Fix #1809
This commit is contained in:
parent 8ba48c14a7
commit 5e65ec978c
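
In short: a discord.File now owns its underlying stream instead of handing a bare file handle to aiohttp, so a rate-limited request can be rewound and retried. A minimal usage sketch of the behaviour this enables (the channel ID, filename, and token are placeholders, not part of the commit):

import discord

client = discord.Client()

@client.event
async def on_ready():
    channel = client.get_channel(123456789012345678)  # placeholder channel ID
    # The File keeps its buffer open across the retry loop in HTTPClient.request;
    # the library closes it once the send completes.
    await channel.send(file=discord.File('cool_image.png'))

client.run('token')  # placeholder token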

discord/abc.py

@@ -756,7 +756,7 @@ class Messageable(metaclass=abc.ABCMeta):
                 raise InvalidArgument('file parameter must be File')
 
             try:
-                data = await state.http.send_files(channel.id, files=[(file.open_file(), file.filename)],
+                data = await state.http.send_files(channel.id, files=[file],
                                                    content=content, tts=tts, embed=embed, nonce=nonce)
             finally:
                 file.close()

discord/file.py

@@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
 """
 
 import os.path
+import io
 
 class File:
     """A parameter object used for :meth:`abc.Messageable.send`

@@ -52,11 +53,28 @@ class File:
         Whether the attachment is a spoiler.
     """
 
-    __slots__ = ('fp', 'filename', '_true_fp')
+    __slots__ = ('fp', 'filename', '_original_pos', '_owner', '_closer')
 
     def __init__(self, fp, filename=None, *, spoiler=False):
         self.fp = fp
-        self._true_fp = None
+
+        if isinstance(fp, io.IOBase):
+            if not (fp.seekable() and fp.readable()):
+                raise ValueError('File buffer {!r} must be seekable and readable'.format(fp))
+            self.fp = fp
+            self._original_pos = fp.tell()
+            self._owner = False
+        else:
+            self.fp = open(fp, 'rb')
+            self._original_pos = 0
+            self._owner = True
+
+        # aiohttp only uses two methods from IOBase
+        # read and close, since I want to control when the files
+        # close, I need to stub it so it doesn't close unless
+        # I tell it to
+        self._closer = self.fp.close
+        self.fp.close = lambda: None
 
         if filename is None:
             if isinstance(fp, str):
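
The stub-and-restore trick described in the comment above is not specific to discord.py; a standalone sketch of the same idea, with illustrative names only:

import io

class OwnedStream:
    # Illustrative wrapper: hide close() from a consumer that closes eagerly,
    # and only really close when the owner says so.
    def __init__(self, fp):
        self.fp = fp
        self._closer = fp.close        # remember the real close
        self.fp.close = lambda: None   # consumer calls to close() become no-ops

    def close(self):
        self.fp.close = self._closer   # restore the real close
        self._closer()                 # and actually close the stream

buf = io.BytesIO(b'data')
wrapped = OwnedStream(buf)
wrapped.fp.close()    # a consumer "closing" the stream does nothing
assert not buf.closed
wrapped.close()       # the owner decides when the stream really closes
assert buf.closed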

@@ -66,15 +84,22 @@ class File:
         else:
             self.filename = filename
 
-        if spoiler and not self.filename.startswith('SPOILER_'):
+        if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
             self.filename = 'SPOILER_' + self.filename
 
-    def open_file(self):
-        fp = self.fp
-        if isinstance(fp, str):
-            self._true_fp = fp = open(fp, 'rb')
-        return fp
+    def reset(self, *, seek=True):
+        # The `seek` parameter is needed because
+        # the retry-loop is iterated over multiple times
+        # starting from 0, as an implementation quirk
+        # the resetting must be done at the beginning
+        # before a request is done, since the first index
+        # is 0, and thus false, then this prevents an
+        # unnecessary seek since it's the first request
+        # done.
+        if seek:
+            self.fp.seek(self._original_pos)
 
     def close(self):
-        if self._true_fp:
-            self._true_fp.close()
+        self.fp.close = self._closer
+        if self._owner:
+            self._closer()
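
The `seek` flag on reset() leans on 0 being falsy: the retry loop passes its own counter, so the very first attempt reads the untouched buffer and only retries rewind it. A standalone illustration with a plain BytesIO (no discord.py involved):

import io

buf = io.BytesIO(b'payload')
original_pos = buf.tell()

for tries in range(5):
    # Mirrors File.reset(seek=tries): tries == 0 is falsy, so the first
    # attempt skips the seek; every retry rewinds to the starting offset.
    if tries:
        buf.seek(original_pos)

    body = buf.read()          # stand-in for handing the stream to the HTTP request
    if body == b'payload':     # pretend the request succeeded
        break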

discord/http.py

@@ -105,7 +105,7 @@ class HTTPClient:
         if self._session.closed:
             self._session = aiohttp.ClientSession(connector=self.connector, loop=self.loop)
 
-    async def request(self, route, *, header_bypass_delay=None, **kwargs):
+    async def request(self, route, *, files=None, header_bypass_delay=None, **kwargs):
         bucket = route.bucket
         method = route.method
         url = route.url

@@ -151,6 +151,10 @@ class HTTPClient:
         await lock.acquire()
         with MaybeUnlock(lock) as maybe_lock:
             for tries in range(5):
+                if files:
+                    for f in files:
+                        f.reset(seek=tries)
+
                 async with self._session.request(method, url, **kwargs) as r:
                     log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status)
 

@@ -334,13 +338,13 @@ class HTTPClient:
 
         form.add_field('payload_json', utils.to_json(payload))
         if len(files) == 1:
-            fp = files[0]
-            form.add_field('file', fp[0], filename=fp[1], content_type='application/octet-stream')
+            file = files[0]
+            form.add_field('file', file.fp, filename=file.filename, content_type='application/octet-stream')
         else:
-            for index, (buffer, filename) in enumerate(files):
-                form.add_field('file%s' % index, buffer, filename=filename, content_type='application/octet-stream')
+            for index, file in enumerate(files):
+                form.add_field('file%s' % index, file.fp, filename=file.filename, content_type='application/octet-stream')
 
-        return self.request(r, data=form)
+        return self.request(r, data=form, files=files)
 
     async def ack_message(self, channel_id, message_id):
         r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id)
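
With send_files now taking File objects, the caller-facing multi-file path looks roughly like the snippet below (the helper and filenames are hypothetical, not part of the diff); each File keeps its stream open, so the retry loop in request() can rewind and resend it:

import discord

async def send_report(channel):
    # Hypothetical helper: both attachments survive a retried request because
    # HTTPClient.request rewinds each File before every attempt.
    attachments = [discord.File('report.txt'), discord.File('chart.png')]
    await channel.send(content='Weekly report', files=attachments)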

discord/webhook.py

@@ -110,17 +110,18 @@ class WebhookAdapter:
         cleanup = None
         if file is not None:
             multipart = {
-                'file': (file.filename, file.open_file(), 'application/octet-stream'),
+                'file': (file.filename, file.fp, 'application/octet-stream'),
                 'payload_json': utils.to_json(payload)
             }
             data = None
             cleanup = file.close
+            files_to_pass = [file]
         elif files is not None:
             multipart = {
                 'payload_json': utils.to_json(payload)
             }
             for i, file in enumerate(files, start=1):
-                multipart['file%i' % i] = (file.filename, file.open_file(), 'application/octet-stream')
+                multipart['file%i' % i] = (file.filename, file.fp, 'application/octet-stream')
             data = None
 
         def _anon():

@@ -128,13 +129,15 @@ class WebhookAdapter:
                 f.close()
 
             cleanup = _anon
+            files_to_pass = files
         else:
            data = payload
            multipart = None
+           files_to_pass = None
 
         url = '%s?wait=%d' % (self._request_url, wait)
         try:
-            maybe_coro = self.request('POST', url, multipart=multipart, payload=data)
+            maybe_coro = self.request('POST', url, multipart=multipart, payload=data, files=files_to_pass)
         finally:
             if cleanup is not None:
                 if not asyncio.iscoroutine(maybe_coro):

@@ -160,9 +163,10 @@ class AsyncWebhookAdapter(WebhookAdapter):
         self.session = session
         self.loop = asyncio.get_event_loop()
 
-    async def request(self, verb, url, payload=None, multipart=None):
+    async def request(self, verb, url, payload=None, multipart=None, *, files=None):
         headers = {}
         data = None
+        files = files or []
         if payload:
             headers['Content-Type'] = 'application/json'
             data = utils.to_json(payload)

@@ -176,6 +180,9 @@ class AsyncWebhookAdapter(WebhookAdapter):
                     data.add_field(key, value)
 
         for tries in range(5):
+            for file in files:
+                file.reset(seek=tries)
+
             async with self.session.request(verb, url, headers=headers, data=data) as r:
                 data = await r.text(encoding='utf-8')
                 if r.headers['Content-Type'] == 'application/json':

@@ -239,9 +246,10 @@ class RequestsWebhookAdapter(WebhookAdapter):
         self.session = session or requests
         self.sleep = sleep
 
-    def request(self, verb, url, payload=None, multipart=None):
+    def request(self, verb, url, payload=None, multipart=None, *, files=None):
         headers = {}
         data = None
+        files = files or []
         if payload:
             headers['Content-Type'] = 'application/json'
             data = utils.to_json(payload)

@@ -250,6 +258,9 @@ class RequestsWebhookAdapter(WebhookAdapter):
             data = {'payload_json': multipart.pop('payload_json')}
 
         for tries in range(5):
+            for file in files:
+                file.reset(seek=tries)
+
             r = self.session.request(verb, url, headers=headers, data=data, files=multipart)
             r.encoding = 'utf-8'
             data = r.text
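
The same File-based convention now flows through the webhook adapters as well; a sketch using the synchronous adapter (the webhook ID, token, and filename are placeholders):

import discord
from discord import Webhook, RequestsWebhookAdapter

# Placeholder ID/token; RequestsWebhookAdapter.request() is the method patched above.
webhook = Webhook.partial(123456789012345678, 'webhook-token',
                          adapter=RequestsWebhookAdapter())
webhook.send('Nightly build finished', file=discord.File('build.log'))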