Rewrite HTTP handling significantly.
This should have a more uniform approach to rate limit handling. Instead of queueing every request, wait until we receive a 429 and then block the requesting bucket until we're done being rate limited. This should reduce the number of 429s done by the API significantly (about 66% avg). This also consistently checks for 502 retries across all requests.
This commit is contained in:
		| @@ -42,6 +42,7 @@ from .enums import ChannelType, ServerRegion | ||||
| from .voice_client import VoiceClient | ||||
| from .iterators import LogsFromIterator | ||||
| from .gateway import * | ||||
| from .http import HTTPClient | ||||
|  | ||||
| import asyncio | ||||
| import aiohttp | ||||
| @@ -52,7 +53,6 @@ import sys, re | ||||
| import tempfile, os, hashlib | ||||
| import itertools | ||||
| import datetime | ||||
| from random import randint as random_integer | ||||
| from collections import namedtuple | ||||
|  | ||||
| PY35 = sys.version_info >= (3, 5) | ||||
| @@ -136,16 +136,8 @@ class Client: | ||||
|  | ||||
|         self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop) | ||||
|  | ||||
|         # Blame Jake for this | ||||
|         user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' | ||||
|  | ||||
|         self.headers = { | ||||
|             'content-type': 'application/json', | ||||
|             'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__) | ||||
|         } | ||||
|  | ||||
|         connector = options.pop('connector', None) | ||||
|         self.session = aiohttp.ClientSession(loop=self.loop, connector=connector) | ||||
|         self.http = HTTPClient(connector, loop=self.loop) | ||||
|  | ||||
|         self._closed = asyncio.Event(loop=self.loop) | ||||
|         self._is_logged_in = asyncio.Event(loop=self.loop) | ||||
| @@ -157,23 +149,21 @@ class Client: | ||||
|         filename = hashlib.md5(email.encode('utf-8')).hexdigest() | ||||
|         return os.path.join(tempfile.gettempdir(), 'discord_py', filename) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _login_via_cache(self, email, password): | ||||
|     def _get_cache_token(self, email, password): | ||||
|         try: | ||||
|             log.info('attempting to login via cache') | ||||
|             cache_file = self._get_cache_filename(email) | ||||
|             self.email = email | ||||
|             with open(cache_file, 'r') as f: | ||||
|                 log.info('login cache file found') | ||||
|                 self.token = f.read() | ||||
|                 self.headers['authorization'] = self.token | ||||
|                 return f.read() | ||||
|  | ||||
|             # at this point our check failed | ||||
|             # so we have to login and get the proper token and then | ||||
|             # redo the cache | ||||
|         except OSError: | ||||
|             log.info('a problem occurred while opening login cache') | ||||
|             pass # file not found et al | ||||
|             return None # file not found et al | ||||
|  | ||||
|     def _update_cache(self, email, password): | ||||
|         try: | ||||
| @@ -222,20 +212,30 @@ class Client: | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _resolve_destination(self, destination): | ||||
|         if isinstance(destination, (Channel, PrivateChannel, Server)): | ||||
|             return destination.id | ||||
|         if isinstance(destination, Channel): | ||||
|             return destination.id, destination.server.id | ||||
|         elif isinstance(destination, PrivateChannel): | ||||
|             return destination.id, None | ||||
|         elif isinstance(destination, Server): | ||||
|             return destination.id, destination.id | ||||
|         elif isinstance(destination, User): | ||||
|             found = self.connection._get_private_channel_by_user(destination.id) | ||||
|             if found is None: | ||||
|                 # Couldn't find the user, so start a PM with them first. | ||||
|                 channel = yield from self.start_private_message(destination) | ||||
|                 return channel.id | ||||
|                 return channel.id, None | ||||
|             else: | ||||
|                 return found.id | ||||
|                 return found.id, None | ||||
|         elif isinstance(destination, Object): | ||||
|             return destination.id | ||||
|             found = self.get_channel(destination.id) | ||||
|             if found is not None: | ||||
|                 return (yield from self._resolve_destination(found)) | ||||
|  | ||||
|             # couldn't find it in cache so YOLO | ||||
|             return destination.id, destination.id | ||||
|         else: | ||||
|             raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') | ||||
|             fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}' | ||||
|             raise InvalidArgument(fmt.format(destination)) | ||||
|  | ||||
|     def __getattr__(self, name): | ||||
|         if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): | ||||
| @@ -291,55 +291,25 @@ class Client: | ||||
|     @asyncio.coroutine | ||||
|     def _login_1(self, token, **kwargs): | ||||
|         log.info('logging in using static token') | ||||
|         self.token = token | ||||
|         self.email = None | ||||
|         if kwargs.pop('bot', True): | ||||
|             self.headers['authorization'] = 'Bot ' + self.token | ||||
|         else: | ||||
|             self.headers['authorization'] = self.token | ||||
|  | ||||
|         resp = yield from self.session.get(endpoints.ME, headers=self.headers) | ||||
|         yield from resp.release() | ||||
|         log.debug(request_logging_format.format(method='GET', response=resp)) | ||||
|  | ||||
|         if resp.status != 200: | ||||
|             if resp.status == 401: | ||||
|                 raise LoginFailure('Improper token has been passed.') | ||||
|             else: | ||||
|                 raise HTTPException(resp, None) | ||||
|  | ||||
|         log.info('token auth returned status code {}'.format(resp.status)) | ||||
|         yield from self.http.static_login(token, bot=kwargs.pop('bot', True)) | ||||
|         self._is_logged_in.set() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _login_2(self, email, password, **kwargs): | ||||
|         # attempt to read the token from cache | ||||
|         if self.cache_auth: | ||||
|             yield from self._login_via_cache(email, password) | ||||
|             if self.is_logged_in: | ||||
|             token = self._get_cache_token() | ||||
|             try: | ||||
|                 self.http.static_login(token, bot=False) | ||||
|             except: | ||||
|                 log.info('cache auth token is out of date') | ||||
|             else: | ||||
|                 self._is_logged_in.set() | ||||
|                 return | ||||
|  | ||||
|         payload = { | ||||
|             'email': email, | ||||
|             'password': password | ||||
|         } | ||||
|  | ||||
|         data = utils.to_json(payload) | ||||
|         resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=resp)) | ||||
|         if resp.status != 200: | ||||
|             yield from resp.release() | ||||
|             if resp.status == 400: | ||||
|                 raise LoginFailure('Improper credentials have been passed.') | ||||
|             else: | ||||
|                 raise HTTPException(resp, None) | ||||
|  | ||||
|         log.info('logging in returned status code {}'.format(resp.status)) | ||||
|         yield from self.http.email_login(email, password) | ||||
|         self.email = email | ||||
|  | ||||
|         body = yield from resp.json(encoding='utf-8') | ||||
|         self.token = body['token'] | ||||
|         self.headers['authorization'] = self.token | ||||
|         self._is_logged_in.set() | ||||
|  | ||||
|         # since we went through all this trouble | ||||
| @@ -395,12 +365,10 @@ class Client: | ||||
|     def logout(self): | ||||
|         """|coro| | ||||
|  | ||||
|         Logs out of Discord and closes all connections.""" | ||||
|         response = yield from self.session.post(endpoints.LOGOUT, headers=self.headers) | ||||
|         yield from response.release() | ||||
|         Logs out of Discord and closes all connections. | ||||
|         """ | ||||
|         yield from self.close() | ||||
|         self._is_logged_in.clear() | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def connect(self): | ||||
| @@ -453,7 +421,7 @@ class Client: | ||||
|             yield from self.ws.close() | ||||
|  | ||||
|  | ||||
|         yield from self.session.close() | ||||
|         yield from self.http.close() | ||||
|         self._closed.set() | ||||
|         self._is_ready.clear() | ||||
|  | ||||
| @@ -774,43 +742,11 @@ class Client: | ||||
|         if not isinstance(user, User): | ||||
|             raise InvalidArgument('user argument must be a User') | ||||
|  | ||||
|         payload = { | ||||
|             'recipient_id': user.id | ||||
|         } | ||||
|  | ||||
|         url = '{}/channels'.format(endpoints.ME) | ||||
|         r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=r, json=payload, data=data)) | ||||
|         data = yield from self.http.start_private_message(user.id) | ||||
|         channel = PrivateChannel(id=data['id'], user=user) | ||||
|         self.connection._add_private_channel(channel) | ||||
|         return channel | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _retry_helper(self, name, *args, retries=0, **kwargs): | ||||
|         req_kwargs = {'headers': self.headers} | ||||
|         req_kwargs.update(kwargs) | ||||
|         resp = yield from self.session.request(*args, **req_kwargs) | ||||
|         tmp = request_logging_format.format(method=resp.method, response=resp) | ||||
|         log_fmt = 'In {}, {}'.format(name, tmp) | ||||
|         log.debug(log_fmt) | ||||
|  | ||||
|         if resp.status == 502 and retries < 5: | ||||
|             # retry the 502 request unconditionally | ||||
|             log.info('Retrying the 502 request to ' + name) | ||||
|             yield from asyncio.sleep(retries + 1) | ||||
|             return (yield from self._retry_helper(name, *args, retries=retries + 1, **kwargs)) | ||||
|  | ||||
|         if resp.status == 429: | ||||
|             retry = float(resp.headers['Retry-After']) / 1000.0 | ||||
|             yield from resp.release() | ||||
|             yield from asyncio.sleep(retry) | ||||
|             return (yield from self._retry_helper(name, *args, retries=retries, **kwargs)) | ||||
|  | ||||
|         return resp | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def send_message(self, destination, content, *, tts=False): | ||||
|         """|coro| | ||||
| @@ -858,23 +794,11 @@ class Client: | ||||
|             The message that was sent. | ||||
|         """ | ||||
|  | ||||
|         channel_id = yield from self._resolve_destination(destination) | ||||
|         channel_id, guild_id = yield from self._resolve_destination(destination) | ||||
|  | ||||
|         content = str(content) | ||||
|  | ||||
|         url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id) | ||||
|         payload = { | ||||
|             'content': content, | ||||
|             'nonce': random_integer(-2**63, 2**63 - 1) | ||||
|         } | ||||
|  | ||||
|         if tts: | ||||
|             payload['tts'] = True | ||||
|  | ||||
|         resp = yield from self._retry_helper('send_message', 'POST', url, data=utils.to_json(payload)) | ||||
|         yield from utils._verify_successful_response(resp) | ||||
|         data = yield from resp.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=resp, json=payload, data=data)) | ||||
|         data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts) | ||||
|         channel = self.get_channel(data.get('channel_id')) | ||||
|         message = Message(channel=channel, **data) | ||||
|         return message | ||||
| @@ -895,14 +819,8 @@ class Client: | ||||
|             The location to send the typing update. | ||||
|         """ | ||||
|  | ||||
|         channel_id = yield from self._resolve_destination(destination) | ||||
|  | ||||
|         url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id) | ||||
|  | ||||
|         response = yield from self.session.post(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         channel_id, guild_id = yield from self._resolve_destination(destination) | ||||
|         yield from self.http.send_typing(channel_id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def send_file(self, destination, fp, *, filename=None, content=None, tts=False): | ||||
| @@ -951,34 +869,18 @@ class Client: | ||||
|             The message sent. | ||||
|         """ | ||||
|  | ||||
|         channel_id = yield from self._resolve_destination(destination) | ||||
|  | ||||
|         url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id) | ||||
|         form = aiohttp.FormData() | ||||
|  | ||||
|         if content is not None: | ||||
|             form.add_field('content', str(content)) | ||||
|  | ||||
|         form.add_field('tts', 'true' if tts else 'false') | ||||
|  | ||||
|         # we don't want the content-type json in this request | ||||
|         headers = self.headers.copy() | ||||
|         headers.pop('content-type', None) | ||||
|         channel_id, guild_id = yield from self._resolve_destination(destination) | ||||
|  | ||||
|         try: | ||||
|             # attempt to open the file and send the request | ||||
|             with open(fp, 'rb') as f: | ||||
|                 form.add_field('file', f, filename=filename, content_type='application/octet-stream') | ||||
|                 response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers) | ||||
|                 buffer = f.read() | ||||
|                 if filename is None: | ||||
|                     filename = fp | ||||
|         except TypeError: | ||||
|             form.add_field('file', fp, filename=filename, content_type='application/octet-stream') | ||||
|             response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers) | ||||
|             buffer = fp | ||||
|  | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         data = yield from response.json(encoding='utf-8') | ||||
|         msg = 'POST {0.url} returned {0.status} with {1} response' | ||||
|         log.debug(msg.format(response, data)) | ||||
|         data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id, | ||||
|                                               filename=filename, content=content, tts=tts) | ||||
|         channel = self.get_channel(data.get('channel_id')) | ||||
|         message = Message(channel=channel, **data) | ||||
|         return message | ||||
| @@ -1004,12 +906,8 @@ class Client: | ||||
|         HTTPException | ||||
|             Deleting the message failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         channel = message.channel | ||||
|         yield from self.http.delete_message(channel.id, message.id, channel.server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def delete_messages(self, messages): | ||||
| @@ -1045,16 +943,9 @@ class Client: | ||||
|         if len(messages) > 100 or len(messages) < 2: | ||||
|             raise ClientException('Can only delete messages in the range of [2, 100]') | ||||
|  | ||||
|         channel_id = messages[0].channel.id | ||||
|         url = '{0}/{1}/messages/bulk_delete'.format(endpoints.CHANNELS, channel_id) | ||||
|         payload = { | ||||
|             'messages': [m.id for m in messages] | ||||
|         } | ||||
|  | ||||
|         response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload)) | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         channel = messages[0].channel | ||||
|         message_ids = [m.id for m in messages] | ||||
|         yield from self.http.delete_messages(channel.id, message_ids, channel.server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def purge_from(self, channel, *, limit=100, check=None, before=None, after=None): | ||||
| @@ -1179,19 +1070,9 @@ class Client: | ||||
|         channel = message.channel | ||||
|         content = str(new_content) | ||||
|  | ||||
|         url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, channel.id, message.id) | ||||
|         payload = { | ||||
|             'content': content | ||||
|         } | ||||
|  | ||||
|         response = yield from self._retry_helper('edit_message', 'PATCH', url, data=utils.to_json(payload)) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         data = yield from response.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=response, json=payload, data=data)) | ||||
|         data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=channel.server.id) | ||||
|         return Message(channel=channel, **data) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _logs_from(self, channel, limit=100, before=None, after=None): | ||||
|         """|coro| | ||||
|  | ||||
| @@ -1242,21 +1123,7 @@ class Client: | ||||
|                 if message.author == client.user: | ||||
|                     counter += 1 | ||||
|         """ | ||||
|         url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id) | ||||
|         params = { | ||||
|             'limit': limit | ||||
|         } | ||||
|  | ||||
|         if before: | ||||
|             params['before'] = before.id | ||||
|         if after: | ||||
|             params['after'] = after.id | ||||
|  | ||||
|         response = yield from self.session.get(url, params=params, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='GET', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         messages = yield from response.json(encoding='utf-8') | ||||
|         return messages | ||||
|         return self.http.logs_from(channel.id, limit, before=before, after=after) | ||||
|  | ||||
|     if PY35: | ||||
|         def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False): | ||||
| @@ -1356,12 +1223,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Kicking failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.kick(member.id, member.server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def ban(self, member, delete_message_days=1): | ||||
| @@ -1390,16 +1252,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Banning failed. | ||||
|         """ | ||||
|  | ||||
|         params = { | ||||
|             'delete-message-days': delete_message_days | ||||
|         } | ||||
|  | ||||
|         url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member) | ||||
|         response = yield from self.session.put(url, params=params, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PUT', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.ban(member.id, member.server.id, delete_message_days) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def unban(self, server, user): | ||||
| @@ -1421,12 +1274,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Unbanning failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.unban(user.id, server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def server_voice_state(self, member, *, mute=False, deafen=False): | ||||
| @@ -1456,17 +1304,7 @@ class Client: | ||||
|         HTTPException | ||||
|             The operation failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) | ||||
|         payload = { | ||||
|             'mute': mute, | ||||
|             'deaf': deafen | ||||
|         } | ||||
|  | ||||
|         response = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.server_voice_state(member.id, member.server.id, mute=mute, deafen=deafen) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def edit_profile(self, password=None, **fields): | ||||
| @@ -1527,30 +1365,21 @@ class Client: | ||||
|         if not_bot_account and password is None: | ||||
|             raise ClientException('Password is required for non-bot accounts.') | ||||
|  | ||||
|         payload = { | ||||
|         args = { | ||||
|             'password': password, | ||||
|             'username': fields.get('username', self.user.name), | ||||
|             'avatar': avatar | ||||
|         } | ||||
|  | ||||
|         if not_bot_account: | ||||
|             payload['email'] = fields.get('email', self.email) | ||||
|             args['email'] = fields.get('email', self.email) | ||||
|  | ||||
|             if 'new_password' in fields: | ||||
|                 payload['new_password'] = fields['new_password'] | ||||
|  | ||||
|  | ||||
|         r = yield from self.session.patch(endpoints.ME, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=r, json=payload, data=data)) | ||||
|                 args['new_password'] = fields['new_password'] | ||||
|  | ||||
|         yield from self.http.edit_profile(**args) | ||||
|         if not_bot_account: | ||||
|             self.token = data['token'] | ||||
|             self.email = data['email'] | ||||
|             self.headers['authorization'] = self.token | ||||
|  | ||||
|             if self.cache_auth: | ||||
|                 self._update_cache(self.email, password) | ||||
| @@ -1608,24 +1437,12 @@ class Client: | ||||
|             Changing the nickname failed. | ||||
|         """ | ||||
|  | ||||
|         nickname = nickname if nickname else '' | ||||
|  | ||||
|         if member == self.user: | ||||
|             fmt = '{0}/{1.server.id}/members/@me/nick' | ||||
|             yield from self.http.change_my_nickname(member.server.id, nickname) | ||||
|         else: | ||||
|             fmt = '{0}/{1.server.id}/members/{1.id}' | ||||
|  | ||||
|         url = fmt.format(endpoints.SERVERS, member) | ||||
|  | ||||
|         payload = { | ||||
|             # oddly enough, this endpoint requires '' to clear the nickname | ||||
|             # instead of the more consistent 'null', this might change in the | ||||
|             # future, or not. | ||||
|             'nick': nickname if nickname else '' | ||||
|         } | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         yield from r.release() | ||||
|             yield from self.http.change_nickname(member.server.id, member.id, nickname) | ||||
|  | ||||
|     # Channel management | ||||
|  | ||||
| @@ -1662,26 +1479,7 @@ class Client: | ||||
|             Editing the channel failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}'.format(endpoints.CHANNELS, channel) | ||||
|         payload = { | ||||
|             'name': options.get('name', channel.name), | ||||
|             'topic': options.get('topic', channel.topic), | ||||
|         } | ||||
|  | ||||
|         user_limit = options.get('user_limit') | ||||
|         if user_limit is not None: | ||||
|             payload['user_limit'] = user_limit | ||||
|  | ||||
|         bitrate = options.get('bitrate') | ||||
|         if bitrate is not None: | ||||
|             payload['bitrate'] = bitrate | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=r, json=payload, data=data)) | ||||
|         yield from self.http.edit_channel(channel.id, **options) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def move_channel(self, channel, position): | ||||
| @@ -1735,13 +1533,7 @@ class Client: | ||||
|             channels.insert(position, channel) | ||||
|  | ||||
|         payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)] | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         yield from r.release() | ||||
|         log.debug(request_success_log.format(json=payload, response=r, data={})) | ||||
|         yield from self.http.patch(url, json=payload, bucket='move_channel') | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def create_channel(self, server, name, type=None): | ||||
| @@ -1779,18 +1571,7 @@ class Client: | ||||
|         if type is None: | ||||
|             type = ChannelType.text | ||||
|  | ||||
|         payload = { | ||||
|             'name': name, | ||||
|             'type': str(type) | ||||
|         } | ||||
|  | ||||
|         url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server) | ||||
|         response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|  | ||||
|         data = yield from response.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=response, data=data, json=payload)) | ||||
|         data = yield from self.http.create_channel(server.id, name, str(type)) | ||||
|         channel = Channel(server=server, **data) | ||||
|         return channel | ||||
|  | ||||
| @@ -1817,12 +1598,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Deleting the channel failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{}/{}'.format(endpoints.CHANNELS, channel.id) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.delete_channel(channel.id) | ||||
|  | ||||
|     # Server management | ||||
|  | ||||
| @@ -1847,12 +1623,7 @@ class Client: | ||||
|         HTTPException | ||||
|             If leaving the server failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{}/@me/guilds/{.id}'.format(endpoints.USERS, server) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.leave_server(server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def delete_server(self, server): | ||||
| @@ -1874,11 +1645,7 @@ class Client: | ||||
|             You do not have permissions to delete the server. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}'.format(endpoints.SERVERS, server) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.delete_server(server.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def create_server(self, name, region=None, icon=None): | ||||
| @@ -1918,17 +1685,7 @@ class Client: | ||||
|         else: | ||||
|             region = region.name | ||||
|  | ||||
|         payload = { | ||||
|             'icon': icon, | ||||
|             'name': name, | ||||
|             'region': region | ||||
|         } | ||||
|  | ||||
|         r = yield from self.session.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(response=r, json=payload, data=data)) | ||||
|         data = yield from self.http.create_server(name, region, icon) | ||||
|         return Server(**data) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
| @@ -1984,30 +1741,18 @@ class Client: | ||||
|             else: | ||||
|                 icon = None | ||||
|  | ||||
|         payload = { | ||||
|             'region': str(fields.get('region', server.region)), | ||||
|             'afk_timeout': fields.get('afk_timeout', server.afk_timeout), | ||||
|             'icon': icon, | ||||
|             'name': fields.get('name', server.name), | ||||
|         } | ||||
|  | ||||
|         afk_channel = fields.get('afk_channel') | ||||
|         if afk_channel is None: | ||||
|             afk_channel = server.afk_channel | ||||
|  | ||||
|         payload['afk_channel'] = getattr(afk_channel, 'id', None) | ||||
|         fields['icon'] = icon | ||||
|         if 'afk_channel' in fields: | ||||
|             fields['afk_channel_id'] = fields['afk_channel'].id | ||||
|  | ||||
|         if 'owner' in fields: | ||||
|             if server.owner != server.me: | ||||
|                 raise InvalidArgument('To transfer ownership you must be the owner of the server.') | ||||
|  | ||||
|             payload['owner_id'] = fields['owner'].id | ||||
|             fields['owner_id'] = fields['owner'].id | ||||
|  | ||||
|         yield from self.http.edit_server(server.id, **fields) | ||||
|  | ||||
|         url = '{0}/{1.id}'.format(endpoints.SERVERS, server) | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         yield from r.release() | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def get_bans(self, server): | ||||
| @@ -2036,11 +1781,7 @@ class Client: | ||||
|             A list of :class:`User` that have been banned. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server) | ||||
|         resp = yield from self.session.get(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='GET', response=resp)) | ||||
|         yield from utils._verify_successful_response(resp) | ||||
|         data = yield from resp.json(encoding='utf-8') | ||||
|         data = yield from self.http.get_bans(server.id) | ||||
|         return [User(**user['user']) for user in data] | ||||
|  | ||||
|     # Invite management | ||||
| @@ -2092,20 +1833,7 @@ class Client: | ||||
|             The invite that was created. | ||||
|         """ | ||||
|  | ||||
|         payload = { | ||||
|             'max_age': options.get('max_age', 0), | ||||
|             'max_uses': options.get('max_uses', 0), | ||||
|             'temporary': options.get('temporary', False), | ||||
|             'xkcdpass': options.get('xkcd', False) | ||||
|         } | ||||
|  | ||||
|         url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination) | ||||
|         response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|  | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         data = yield from response.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(json=payload, response=response, data=data)) | ||||
|         data = yield from self.http.create_invite(destination.id, **options) | ||||
|         self._fill_invite_data(data) | ||||
|         return Invite(**data) | ||||
|  | ||||
| @@ -2139,12 +1867,8 @@ class Client: | ||||
|             The invite from the URL/ID. | ||||
|         """ | ||||
|  | ||||
|         destination = self._resolve_invite(url) | ||||
|         rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) | ||||
|         response = yield from self.session.get(rurl, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='GET', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         data = yield from response.json(encoding='utf-8') | ||||
|         invite_id = self._resolve_invite(url) | ||||
|         data = yield from self.http.get_invite(invite_id) | ||||
|         self._fill_invite_data(data) | ||||
|         return Invite(**data) | ||||
|  | ||||
| @@ -2174,11 +1898,7 @@ class Client: | ||||
|             The list of invites that are currently active. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server) | ||||
|         resp = yield from self.session.get(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='GET', response=resp)) | ||||
|         yield from utils._verify_successful_response(resp) | ||||
|         data = yield from resp.json(encoding='utf-8') | ||||
|         data = yield from self.http.invites_from(server.id) | ||||
|         result = [] | ||||
|         for invite in data: | ||||
|             channel = server.get_channel(invite['channel']['id']) | ||||
| @@ -2210,12 +1930,8 @@ class Client: | ||||
|             The invite is invalid or expired. | ||||
|         """ | ||||
|  | ||||
|         destination = self._resolve_invite(invite) | ||||
|         url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) | ||||
|         response = yield from self.session.post(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         invite_id = self._resolve_invite(invite) | ||||
|         yield from self.http.accept_invite(invite_id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def delete_invite(self, invite): | ||||
| @@ -2241,12 +1957,8 @@ class Client: | ||||
|             Revoking the invite failed. | ||||
|         """ | ||||
|  | ||||
|         destination = self._resolve_invite(invite) | ||||
|         url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         invite_id = self._resolve_invite(invite) | ||||
|         yield from self.http.delete_invite(invite_id) | ||||
|  | ||||
|     # Role management | ||||
|  | ||||
| @@ -2298,13 +2010,7 @@ class Client: | ||||
|             roles.append(role.id) | ||||
|  | ||||
|         payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         data = yield from r.json() | ||||
|         log.debug(request_success_log.format(json=payload, response=r, data=data)) | ||||
|         yield from self.http.patch(url, json=payload, bucket='move_role') | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def edit_role(self, server, role, **fields): | ||||
| @@ -2345,11 +2051,6 @@ class Client: | ||||
|             Editing the role failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role) | ||||
|         color = fields.get('color') | ||||
|         if color is None: | ||||
|             color = fields.get('colour', role.colour) | ||||
|  | ||||
|         payload = { | ||||
|             'name': fields.get('name', role.name), | ||||
|             'permissions': fields.get('permissions', role.permissions).value, | ||||
| @@ -2358,12 +2059,7 @@ class Client: | ||||
|             'mentionable': fields.get('mentionable', role.mentionable) | ||||
|         } | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         log.debug(request_success_log.format(json=payload, response=r, data=data)) | ||||
|         yield from self.http.edit_role(server.id, role.id, **payload) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def delete_role(self, server, role): | ||||
| @@ -2386,24 +2082,11 @@ class Client: | ||||
|             Deleting the role failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.delete_role(server.id, role.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def _replace_roles(self, member, roles): | ||||
|         url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) | ||||
|  | ||||
|         payload = { | ||||
|             'roles': roles | ||||
|         } | ||||
|  | ||||
|         r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         yield from r.release() | ||||
|         yield from self.http.replace_roles(member.id, member.server.id, roles) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def add_roles(self, member, *roles): | ||||
| @@ -2521,12 +2204,7 @@ class Client: | ||||
|             is stored in cache. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server) | ||||
|         r = yield from self.session.post(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='POST', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|  | ||||
|         data = yield from r.json(encoding='utf-8') | ||||
|         data = yield from self.http.create_role(server.id) | ||||
|         role = Role(server=server, **data) | ||||
|  | ||||
|         # we have to call edit because you can't pass a payload to the | ||||
| @@ -2581,8 +2259,6 @@ class Client: | ||||
|             or the target type was not :class:`Role` or :class:`Member`. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target) | ||||
|  | ||||
|         allow = Permissions.none() if allow is None else allow | ||||
|         deny = Permissions.none() if deny is None else deny | ||||
|  | ||||
| @@ -2592,23 +2268,14 @@ class Client: | ||||
|         deny =  deny.value | ||||
|         allow = allow.value | ||||
|  | ||||
|         payload = { | ||||
|             'id': target.id, | ||||
|             'allow': allow, | ||||
|             'deny': deny | ||||
|         } | ||||
|  | ||||
|         if isinstance(target, Member): | ||||
|             payload['type'] = 'member' | ||||
|             perm_type = 'member' | ||||
|         elif isinstance(target, Role): | ||||
|             payload['type'] = 'role' | ||||
|             perm_type = 'role' | ||||
|         else: | ||||
|             raise InvalidArgument('target parameter must be either discord.Member or discord.Role') | ||||
|  | ||||
|         r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PUT', response=r)) | ||||
|         yield from utils._verify_successful_response(r) | ||||
|         yield from r.release() | ||||
|         yield from self.http.edit_channel_permissions(channel.id, target.id, allow, deny, perm_type) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def delete_channel_permissions(self, channel, target): | ||||
| @@ -2637,12 +2304,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Deleting channel specific permissions failed. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target) | ||||
|         response = yield from self.session.delete(url, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='DELETE', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.delete_channel_permissions(channel.id, target.id) | ||||
|  | ||||
|     # Voice management | ||||
|  | ||||
| @@ -2676,18 +2338,10 @@ class Client: | ||||
|             You do not have permissions to move the member. | ||||
|         """ | ||||
|  | ||||
|         url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member) | ||||
|  | ||||
|         if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: | ||||
|             raise InvalidArgument('The channel provided must be a voice channel.') | ||||
|  | ||||
|         payload = utils.to_json({ | ||||
|             'channel_id': channel.id | ||||
|         }) | ||||
|         response = yield from self.session.patch(url, data=payload, headers=self.headers) | ||||
|         log.debug(request_logging_format.format(method='PATCH', response=response)) | ||||
|         yield from utils._verify_successful_response(response) | ||||
|         yield from response.release() | ||||
|         yield from self.http.move_member(member.id, member.server.id, channel.id) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def join_voice_channel(self, channel): | ||||
| @@ -2817,10 +2471,7 @@ class Client: | ||||
|         HTTPException | ||||
|             Retrieving the information failed somehow. | ||||
|         """ | ||||
|         url = '{}/@me'.format(endpoints.APPLICATIONS) | ||||
|         resp = yield from self.session.get(url, headers=self.headers) | ||||
|         yield from utils._verify_successful_response(resp) | ||||
|         data = yield from resp.json() | ||||
|         data = yield from self.http.application_info() | ||||
|         return AppInfo(id=data['id'], name=data['name'], | ||||
|                        description=data['description'], icon=data['icon']) | ||||
|  | ||||
|   | ||||
| @@ -40,7 +40,7 @@ import struct | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| __all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket', | ||||
| __all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket', | ||||
|             'KeepAliveHandler', 'VoiceKeepAliveHandler', | ||||
|             'DiscordVoiceWebSocket', 'ResumeWebSocket' ] | ||||
|  | ||||
| @@ -97,36 +97,6 @@ class VoiceKeepAliveHandler(KeepAliveHandler): | ||||
|             'd': int(time.time() * 1000) | ||||
|         } | ||||
|  | ||||
|  | ||||
| @asyncio.coroutine | ||||
| def get_gateway(token, *, loop=None): | ||||
|     """Returns the gateway URL for connecting to the WebSocket. | ||||
|  | ||||
|     Parameters | ||||
|     ----------- | ||||
|     token : str | ||||
|         The discord authentication token. | ||||
|     loop | ||||
|         The event loop. | ||||
|  | ||||
|     Raises | ||||
|     ------ | ||||
|     GatewayNotFound | ||||
|         When the gateway is not returned gracefully. | ||||
|     """ | ||||
|     headers = { | ||||
|         'authorization': token, | ||||
|         'content-type': 'application/json' | ||||
|     } | ||||
|  | ||||
|     with aiohttp.ClientSession(loop=loop) as session: | ||||
|         resp = yield from session.get(endpoints.GATEWAY, headers=headers) | ||||
|         if resp.status != 200: | ||||
|             yield from resp.release() | ||||
|             raise GatewayNotFound() | ||||
|         data = yield from resp.json(encoding='utf-8') | ||||
|         return data.get('url') + '?encoding=json&v=4' | ||||
|  | ||||
| class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|     """Implements a WebSocket for Discord's gateway v4. | ||||
|  | ||||
| @@ -190,11 +160,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|  | ||||
|         This is for internal use only. | ||||
|         """ | ||||
|         gateway = yield from get_gateway(client.token, loop=client.loop) | ||||
|         gateway = yield from client.http.get_gateway() | ||||
|         ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) | ||||
|  | ||||
|         # dynamically add attributes needed | ||||
|         ws.token = client.token | ||||
|         ws.token = client.http.token | ||||
|         ws._connection = client.connection | ||||
|         ws._dispatch = client.dispatch | ||||
|         ws.gateway = gateway | ||||
| @@ -505,7 +475,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|                 'server_id': client.guild_id, | ||||
|                 'user_id': client.user.id, | ||||
|                 'session_id': client.session_id, | ||||
|                 'token': client.token | ||||
|                 'token': client.http.token | ||||
|             } | ||||
|         } | ||||
|  | ||||
|   | ||||
							
								
								
									
										484
									
								
								discord/http.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										484
									
								
								discord/http.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,484 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
|  | ||||
| """ | ||||
| The MIT License (MIT) | ||||
|  | ||||
| Copyright (c) 2015-2016 Rapptz | ||||
|  | ||||
| Permission is hereby granted, free of charge, to any person obtaining a | ||||
| copy of this software and associated documentation files (the "Software"), | ||||
| to deal in the Software without restriction, including without limitation | ||||
| the rights to use, copy, modify, merge, publish, distribute, sublicense, | ||||
| and/or sell copies of the Software, and to permit persons to whom the | ||||
| Software is furnished to do so, subject to the following conditions: | ||||
|  | ||||
| The above copyright notice and this permission notice shall be included in | ||||
| all copies or substantial portions of the Software. | ||||
|  | ||||
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | ||||
| OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||
| FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||||
| DEALINGS IN THE SOFTWARE. | ||||
| """ | ||||
|  | ||||
| import aiohttp | ||||
| import asyncio | ||||
| import json | ||||
| import sys | ||||
| import logging | ||||
| import io | ||||
| import inspect | ||||
| import weakref | ||||
| from random import randint as random_integer | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound | ||||
| from . import utils, __version__ | ||||
|  | ||||
| @asyncio.coroutine | ||||
| def json_or_text(response): | ||||
|     text = yield from response.text(encoding='utf-8') | ||||
|     if response.headers['content-type'] == 'application/json': | ||||
|         return json.loads(text) | ||||
|     return text | ||||
|  | ||||
| def _func_(): | ||||
|     # emulate __func__ from C++ | ||||
|     return inspect.currentframe().f_back.f_code.co_name | ||||
|  | ||||
| class HTTPClient: | ||||
|     """Represents an HTTP client sending HTTP requests to the Discord API.""" | ||||
|  | ||||
|     BASE          = 'https://discordapp.com' | ||||
|     API_BASE      = BASE     + '/api' | ||||
|     GATEWAY       = API_BASE + '/gateway' | ||||
|     USERS         = API_BASE + '/users' | ||||
|     ME            = USERS    + '/@me' | ||||
|     REGISTER      = API_BASE + '/auth/register' | ||||
|     LOGIN         = API_BASE + '/auth/login' | ||||
|     LOGOUT        = API_BASE + '/auth/logout' | ||||
|     GUILDS        = API_BASE + '/guilds' | ||||
|     CHANNELS      = API_BASE + '/channels' | ||||
|     APPLICATIONS  = API_BASE + '/oauth2/applications' | ||||
|  | ||||
|     SUCCESS_LOG = '{method} {url} with {json} has received {text}' | ||||
|     REQUEST_LOG = '{method} {url} has returned {status}' | ||||
|  | ||||
|     def __init__(self, connector=None, *, loop=None): | ||||
|         self.loop = asyncio.get_event_loop() if loop is None else loop | ||||
|         self.connector = connector | ||||
|         self.session = aiohttp.ClientSession(connector=connector, loop=self.loop) | ||||
|         self._locks = weakref.WeakValueDictionary() | ||||
|         self.token = None | ||||
|         self.bot_token = False | ||||
|  | ||||
|         user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' | ||||
|         self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def request(self, method, url, *, bucket=None, **kwargs): | ||||
|         lock = self._locks.get(bucket) | ||||
|         if lock is None: | ||||
|             lock = asyncio.Lock(loop=self.loop) | ||||
|             if bucket is not None: | ||||
|                 self._locks[bucket] = lock | ||||
|  | ||||
|         # header creation | ||||
|         headers = { | ||||
|             'User-Agent': self.user_agent, | ||||
|         } | ||||
|  | ||||
|         if self.token is not None: | ||||
|             headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token | ||||
|  | ||||
|         # some checking if it's a JSON request | ||||
|         if 'json' in kwargs: | ||||
|             headers['Content-Type'] = 'application/json' | ||||
|             kwargs['data'] = utils.to_json(kwargs.pop('json')) | ||||
|  | ||||
|         kwargs['headers'] = headers | ||||
|         with (yield from lock): | ||||
|             for tries in range(5): | ||||
|                 r = yield from self.session.request(method, url, **kwargs) | ||||
|                 log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status)) | ||||
|                 try: | ||||
|                     # even errors have text involved in them so this is safe to call | ||||
|                     data = yield from json_or_text(r) | ||||
|  | ||||
|                     # the request was successful so just return the text/json | ||||
|                     if 300 > r.status >= 200: | ||||
|                         log.debug(self.SUCCESS_LOG.format(method=method, url=url, | ||||
|                                                           json=kwargs.get('data'), text=data)) | ||||
|                         return data | ||||
|  | ||||
|                     # we are being rate limited | ||||
|                     if r.status == 429: | ||||
|                         fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"' | ||||
|  | ||||
|                         # sleep a bit | ||||
|                         retry_after = data['retry_after'] / 1000.0 | ||||
|                         log.info(fmt.format(retry_after, bucket)) | ||||
|                         yield from asyncio.sleep(retry_after) | ||||
|                         continue | ||||
|  | ||||
|                     # we've received a 502, unconditional retry | ||||
|                     if r.status == 502 and tries <= 5: | ||||
|                         yield from asyncio.sleep(1 + tries * 2) | ||||
|                         continue | ||||
|  | ||||
|                     # the usual error cases | ||||
|                     if r.status == 403: | ||||
|                         raise Forbidden(r, data) | ||||
|                     elif r.status == 404: | ||||
|                         raise NotFound(r, data) | ||||
|                     else: | ||||
|                         raise HTTPException(r, data) | ||||
|                 finally: | ||||
|                     # clean-up just in case | ||||
|                     yield from r.release() | ||||
|  | ||||
|     def get(self, *args, **kwargs): | ||||
|         return self.request('GET', *args, **kwargs) | ||||
|  | ||||
|     def put(self, *args, **kwargs): | ||||
|         return self.request('PUT', *args, **kwargs) | ||||
|  | ||||
|     def patch(self, *args, **kwargs): | ||||
|         return self.request('PATCH', *args, **kwargs) | ||||
|  | ||||
|     def delete(self, *args, **kwargs): | ||||
|         return self.request('DELETE', *args, **kwargs) | ||||
|  | ||||
|     def post(self, *args, **kwargs): | ||||
|         return self.request('POST', *args, **kwargs) | ||||
|  | ||||
|     # state management | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def close(self): | ||||
|         yield from self.session.close() | ||||
|  | ||||
|     def recreate(self): | ||||
|         self.session = aiohttp.ClientSession(self.connector, loop=self.loop) | ||||
|  | ||||
|     def _token(self, token, *, bot=True): | ||||
|         self.token = token | ||||
|         self.bot_token = bot | ||||
|  | ||||
|     # login management | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def email_login(self, email, password): | ||||
|         payload = { | ||||
|             'email': email, | ||||
|             'password': password | ||||
|         } | ||||
|  | ||||
|         try: | ||||
|             data = yield from self.post(self.LOGIN, json=payload, bucket=_func_()) | ||||
|         except HTTPException as e: | ||||
|             if e.response.status == 400: | ||||
|                 raise LoginFailure('Improper credentials have been passed.') from e | ||||
|             raise | ||||
|  | ||||
|         self._token(data['token'], bot=False) | ||||
|         return data | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def static_login(self, token, *, bot): | ||||
|         old_state = (self.token, self.bot_token) | ||||
|         self._token(token, bot=bot) | ||||
|  | ||||
|         try: | ||||
|             data = yield from self.get(self.ME) | ||||
|         except HTTPException as e: | ||||
|             self._token(*old_state) | ||||
|             if e.response.status == 401: | ||||
|                 raise LoginFailure('Improper token has been passed.') from e | ||||
|             raise e | ||||
|  | ||||
|         return data | ||||
|  | ||||
|     def logout(self): | ||||
|         return self.post(self.LOGOUT, bucket=_func_()) | ||||
|  | ||||
|     # Message management | ||||
|  | ||||
|     def start_private_message(self, user_id): | ||||
|         payload = { | ||||
|             'recipient_id': user_id | ||||
|         } | ||||
|  | ||||
|         return self.post(self.ME + '/channels', json=payload, bucket=_func_()) | ||||
|  | ||||
|     def send_message(self, channel_id, content, *, guild_id=None, tts=False): | ||||
|         url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) | ||||
|         payload = { | ||||
|             'content': str(content), | ||||
|             'nonce': random_integer(-2**63, 2**63 - 1) | ||||
|         } | ||||
|  | ||||
|         if tts: | ||||
|             payload['tts'] = True | ||||
|  | ||||
|         return self.post(url, json=payload, bucket='messages:' + str(guild_id)) | ||||
|  | ||||
|     def send_typing(self, channel_id): | ||||
|         url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id) | ||||
|         return self.post(url, bucket=_func_()) | ||||
|  | ||||
|     def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False): | ||||
|         url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) | ||||
|         form = aiohttp.FormData() | ||||
|  | ||||
|         if content is not None: | ||||
|             form.add_field('content', str(content)) | ||||
|  | ||||
|         form.add_field('tts', 'true' if tts else 'false') | ||||
|         form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream') | ||||
|  | ||||
|         return self.post(url, data=form, bucket='messages:' + str(guild_id)) | ||||
|  | ||||
|     def delete_message(self, channel_id, message_id, guild_id=None): | ||||
|         url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) | ||||
|         bucket = '{}:{}'.format(_func_(), guild_id) | ||||
|         return self.delete(url, bucket=bucket) | ||||
|  | ||||
|     def delete_messages(self, channel_id, message_ids, guild_id=None): | ||||
|         url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id) | ||||
|         payload = { | ||||
|             'messages': message_ids | ||||
|         } | ||||
|         bucket = '{}:{}'.format(_func_(), guild_id) | ||||
|         return self.post(url, json=payload, bucket=bucket) | ||||
|  | ||||
|     def edit_message(self, message_id, channel_id, content, *, guild_id=None): | ||||
|         url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id) | ||||
|         payload = { | ||||
|             'content': str(content) | ||||
|         } | ||||
|         return self.patch(url, json=payload, bucket='messages:' + str(guild_id)) | ||||
|  | ||||
|  | ||||
|     def logs_from(self, channel_id, limit, before=None, after=None): | ||||
|         url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id) | ||||
|         params = { | ||||
|             'limit': limit | ||||
|         } | ||||
|  | ||||
|         if before: | ||||
|             params['before'] = before | ||||
|         if after: | ||||
|             params['after'] = after | ||||
|  | ||||
|         return self.get(url, params=params, bucket=_func_()) | ||||
|  | ||||
|     # Member management | ||||
|  | ||||
|     def kick(self, user_id, guild_id): | ||||
|         url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     def ban(self, user_id, guild_id, delete_message_days=1): | ||||
|         url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) | ||||
|         params = { | ||||
|             'delete-message-days': delete_message_days | ||||
|         } | ||||
|         return self.put(url, params=params, bucket=_func_()) | ||||
|  | ||||
|     def unban(self, user_id, guild_id): | ||||
|         url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False): | ||||
|         url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) | ||||
|         payload = { | ||||
|             'mute': mute, | ||||
|             'deafen': deafen | ||||
|         } | ||||
|         return self.patch(url, json=payload, bucket='members:' + str(guild_id)) | ||||
|  | ||||
|     def edit_profile(self, password, username, avatar, **fields): | ||||
|         payload = { | ||||
|             'password': password, | ||||
|             'username': username, | ||||
|             'avatar': avatar | ||||
|         } | ||||
|  | ||||
|         if 'email' in fields: | ||||
|             payload['email'] = fields['email'] | ||||
|  | ||||
|         if 'new_password' in fields: | ||||
|             payload['new_password'] = fields['new_password'] | ||||
|  | ||||
|         return self.patch(self.ME, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def change_my_nickname(self, guild_id, nickname): | ||||
|         url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id) | ||||
|         payload = { | ||||
|             'nick': nickname | ||||
|         } | ||||
|         bucket = '{}:{}'.format(_func_(), guild_id) | ||||
|         return self.patch(url, json=payload, bucket=bucket) | ||||
|  | ||||
|     def change_nickname(self, guild_id, user_id, nickname): | ||||
|         url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) | ||||
|         payload = { | ||||
|             'nick': nickname | ||||
|         } | ||||
|         bucket = '{}:{}'.format(_func_(), guild_id) | ||||
|         return self.patch(url, json=payload, bucket=bucket) | ||||
|  | ||||
|     # Channel management | ||||
|  | ||||
|     def edit_channel(self, channel_id, **options): | ||||
|         url = '{0.CHANNELS}/{1}'.format(self, channel_id) | ||||
|  | ||||
|         valid_keys = ('name', 'topic', 'bitrate', 'user_limit') | ||||
|         payload = { | ||||
|             k: v for k, v in options.items() if k in valid_keys | ||||
|         } | ||||
|  | ||||
|         return self.patch(url, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def create_channel(self, guild_id, name, channe_type): | ||||
|         url = '{0.GUILDS}/{1}/channels'.format(self, guild_id) | ||||
|         payload = { | ||||
|             'name': name, | ||||
|             'type': channe_type | ||||
|         } | ||||
|  | ||||
|         return self.post(url, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def delete_channel(self, channel_id): | ||||
|         url = '{0.CHANNELS}/{1}'.format(self, channel_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     # Server management | ||||
|  | ||||
|     def leave_server(self, guild_id): | ||||
|         url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     def delete_server(self, guild_id): | ||||
|         url = '{0.GUILDS}/{1}'.format(self, guild_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     def create_server(self, name, region, icon): | ||||
|         payload = { | ||||
|             'name': name, | ||||
|             'icon': icon, | ||||
|             'region': region | ||||
|         } | ||||
|  | ||||
|         return self.post(self.GUILDS, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def edit_server(self, guild_id, **fields): | ||||
|         valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id', | ||||
|                       'afk_channel_id', 'splash', 'verification_level') | ||||
|  | ||||
|         payload = { | ||||
|             k: v for k, v in fields.items() if k in valid_keys | ||||
|         } | ||||
|  | ||||
|         url = '{0.GUILDS}/{1}'.format(self, guild_id) | ||||
|         return self.patch(url, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def get_bans(self, guild_id): | ||||
|         url = '{0.GUILDS}/{1}/bans'.format(self, guild_id) | ||||
|         return self.get(url, bucket=_func_()) | ||||
|  | ||||
|     # Invite management | ||||
|  | ||||
|     def create_invite(self, channel_id, **options): | ||||
|         url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id) | ||||
|         payload = { | ||||
|             'max_age': options.get('max_age', 0), | ||||
|             'max_uses': options.get('max_uses', 0), | ||||
|             'temporary': options.get('temporary', False), | ||||
|             'xkcdpass': options.get('xkcd', False) | ||||
|         } | ||||
|  | ||||
|         return self.post(url, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def get_invite(self, invite_id): | ||||
|         url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) | ||||
|         return self.get(url, bucket=_func_()) | ||||
|  | ||||
|     def invites_from(self, guild_id): | ||||
|         url = '{0.GUILDS}/{1}/invites'.format(self, guild_id) | ||||
|         return self.get(url, bucket=_func_()) | ||||
|  | ||||
|     def accept_invite(self, invite_id): | ||||
|         url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) | ||||
|         return self.post(url, bucket=_func_()) | ||||
|  | ||||
|     def delete_invite(self, invite_id): | ||||
|         url = '{0.API_BASE}/invite/{1}'.format(self, invite_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     # Role management | ||||
|  | ||||
|     def edit_role(self, guild_id, role_id, **fields): | ||||
|         url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) | ||||
|         valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') | ||||
|         payload = { | ||||
|             k: v for k, v in fields.items() if k in valid_keys | ||||
|         } | ||||
|         return self.patch(url, json=payload, bucket='roles:' + str(guild_id)) | ||||
|  | ||||
|     def delete_role(self, guild_id, role_id): | ||||
|         url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     def replace_roles(self, user_id, guild_id, role_ids): | ||||
|         url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) | ||||
|         payload = { | ||||
|             'roles': role_ids | ||||
|         } | ||||
|         return self.patch(url, json=payload, bucket='members:' + str(guild_id)) | ||||
|  | ||||
|     def create_role(self, guild_id): | ||||
|         url = '{0.GUILDS}/{1}/roles'.format(self, guild_id) | ||||
|         return self.post(url, bucket=_func_()) | ||||
|  | ||||
|     def edit_channel_permissions(self, channel_id, target, allow, deny, type): | ||||
|         url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) | ||||
|         payload = { | ||||
|             'id': target, | ||||
|             'allow': allow, | ||||
|             'deny': deny, | ||||
|             'type': type | ||||
|         } | ||||
|         return self.put(url, json=payload, bucket=_func_()) | ||||
|  | ||||
|     def delete_channel_permissions(self, channel_id, target): | ||||
|         url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target) | ||||
|         return self.delete(url, bucket=_func_()) | ||||
|  | ||||
|     # Voice management | ||||
|  | ||||
|     def move_member(self, user_id, guild_id, channel_id): | ||||
|         url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id) | ||||
|         payload = { | ||||
|             'channel_id': channel_id | ||||
|         } | ||||
|         return self.patch(url, json=payload, bucket='members:' + str(guild_id)) | ||||
|  | ||||
|     # Misc | ||||
|  | ||||
|     def application_info(self): | ||||
|         url = '{0.APPLICATIONS}/@me'.format(self) | ||||
|         return self.get(url, bucket=_func_()) | ||||
|  | ||||
|     @asyncio.coroutine | ||||
|     def get_gateway(self): | ||||
|         try: | ||||
|             data = yield from self.get(self.GATEWAY, bucket=_func_()) | ||||
|         except HTTPException as e: | ||||
|             raise GatewayNotFound() from e | ||||
|         return data.get('url') + '?encoding=json&v=4' | ||||
		Reference in New Issue
	
	Block a user