Take back ownership of files from aiohttp for retrying requests.

Fix #1809
This commit is contained in:
Rapptz
2019-03-18 07:54:36 -04:00
parent 8ba48c14a7
commit 5e65ec978c
4 changed files with 62 additions and 22 deletions

View File

@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
"""
import os.path
import io
class File:
"""A parameter object used for :meth:`abc.Messageable.send`
@ -52,11 +53,28 @@ class File:
Whether the attachment is a spoiler.
"""
__slots__ = ('fp', 'filename', '_true_fp')
__slots__ = ('fp', 'filename', '_original_pos', '_owner', '_closer')
def __init__(self, fp, filename=None, *, spoiler=False):
self.fp = fp
self._true_fp = None
if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()):
raise ValueError('File buffer {!r} must be seekable and readable'.format(fp))
self.fp = fp
self._original_pos = fp.tell()
self._owner = False
else:
self.fp = open(fp, 'rb')
self._original_pos = 0
self._owner = True
# aiohttp only uses two methods from IOBase
# read and close, since I want to control when the files
# close, I need to stub it so it doesn't close unless
# I tell it to
self._closer = self.fp.close
self.fp.close = lambda: None
if filename is None:
if isinstance(fp, str):
@ -66,15 +84,22 @@ class File:
else:
self.filename = filename
if spoiler and not self.filename.startswith('SPOILER_'):
if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'):
self.filename = 'SPOILER_' + self.filename
def open_file(self):
fp = self.fp
if isinstance(fp, str):
self._true_fp = fp = open(fp, 'rb')
return fp
def reset(self, *, seek=True):
# The `seek` parameter is needed because
# the retry-loop is iterated over multiple times
# starting from 0, as an implementation quirk
# the resetting must be done at the beginning
# before a request is done, since the first index
# is 0, and thus false, then this prevents an
# unnecessary seek since it's the first request
# done.
if seek:
self.fp.seek(self._original_pos)
def close(self):
if self._true_fp:
self._true_fp.close()
self.fp.close = self._closer
if self._owner:
self._closer()