mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-03 18:42:43 +00:00
Rewrite HTTP handling significantly.
This should have a more uniform approach to rate limit handling. Instead of queueing every request, wait until we receive a 429 and then block the requesting bucket until we're done being rate limited. This should reduce the number of 429s done by the API significantly (about 66% avg). This also consistently checks for 502 retries across all requests.
This commit is contained in:
parent
fa36a449e9
commit
1fba1b06fa
@ -42,6 +42,7 @@ from .enums import ChannelType, ServerRegion
|
||||
from .voice_client import VoiceClient
|
||||
from .iterators import LogsFromIterator
|
||||
from .gateway import *
|
||||
from .http import HTTPClient
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
@ -52,7 +53,6 @@ import sys, re
|
||||
import tempfile, os, hashlib
|
||||
import itertools
|
||||
import datetime
|
||||
from random import randint as random_integer
|
||||
from collections import namedtuple
|
||||
|
||||
PY35 = sys.version_info >= (3, 5)
|
||||
@ -136,16 +136,8 @@ class Client:
|
||||
|
||||
self.connection = ConnectionState(self.dispatch, self.request_offline_members, max_messages, loop=self.loop)
|
||||
|
||||
# Blame Jake for this
|
||||
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
|
||||
|
||||
self.headers = {
|
||||
'content-type': 'application/json',
|
||||
'user-agent': user_agent.format(library_version, sys.version_info, aiohttp.__version__)
|
||||
}
|
||||
|
||||
connector = options.pop('connector', None)
|
||||
self.session = aiohttp.ClientSession(loop=self.loop, connector=connector)
|
||||
self.http = HTTPClient(connector, loop=self.loop)
|
||||
|
||||
self._closed = asyncio.Event(loop=self.loop)
|
||||
self._is_logged_in = asyncio.Event(loop=self.loop)
|
||||
@ -157,23 +149,21 @@ class Client:
|
||||
filename = hashlib.md5(email.encode('utf-8')).hexdigest()
|
||||
return os.path.join(tempfile.gettempdir(), 'discord_py', filename)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _login_via_cache(self, email, password):
|
||||
def _get_cache_token(self, email, password):
|
||||
try:
|
||||
log.info('attempting to login via cache')
|
||||
cache_file = self._get_cache_filename(email)
|
||||
self.email = email
|
||||
with open(cache_file, 'r') as f:
|
||||
log.info('login cache file found')
|
||||
self.token = f.read()
|
||||
self.headers['authorization'] = self.token
|
||||
return f.read()
|
||||
|
||||
# at this point our check failed
|
||||
# so we have to login and get the proper token and then
|
||||
# redo the cache
|
||||
except OSError:
|
||||
log.info('a problem occurred while opening login cache')
|
||||
pass # file not found et al
|
||||
return None # file not found et al
|
||||
|
||||
def _update_cache(self, email, password):
|
||||
try:
|
||||
@ -222,20 +212,30 @@ class Client:
|
||||
|
||||
@asyncio.coroutine
|
||||
def _resolve_destination(self, destination):
|
||||
if isinstance(destination, (Channel, PrivateChannel, Server)):
|
||||
return destination.id
|
||||
if isinstance(destination, Channel):
|
||||
return destination.id, destination.server.id
|
||||
elif isinstance(destination, PrivateChannel):
|
||||
return destination.id, None
|
||||
elif isinstance(destination, Server):
|
||||
return destination.id, destination.id
|
||||
elif isinstance(destination, User):
|
||||
found = self.connection._get_private_channel_by_user(destination.id)
|
||||
if found is None:
|
||||
# Couldn't find the user, so start a PM with them first.
|
||||
channel = yield from self.start_private_message(destination)
|
||||
return channel.id
|
||||
return channel.id, None
|
||||
else:
|
||||
return found.id
|
||||
return found.id, None
|
||||
elif isinstance(destination, Object):
|
||||
return destination.id
|
||||
found = self.get_channel(destination.id)
|
||||
if found is not None:
|
||||
return (yield from self._resolve_destination(found))
|
||||
|
||||
# couldn't find it in cache so YOLO
|
||||
return destination.id, destination.id
|
||||
else:
|
||||
raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object')
|
||||
fmt = 'Destination must be Channel, PrivateChannel, User, or Object. Received {0.__class__.__name__}'
|
||||
raise InvalidArgument(fmt.format(destination))
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
|
||||
@ -291,55 +291,25 @@ class Client:
|
||||
@asyncio.coroutine
|
||||
def _login_1(self, token, **kwargs):
|
||||
log.info('logging in using static token')
|
||||
self.token = token
|
||||
self.email = None
|
||||
if kwargs.pop('bot', True):
|
||||
self.headers['authorization'] = 'Bot ' + self.token
|
||||
else:
|
||||
self.headers['authorization'] = self.token
|
||||
|
||||
resp = yield from self.session.get(endpoints.ME, headers=self.headers)
|
||||
yield from resp.release()
|
||||
log.debug(request_logging_format.format(method='GET', response=resp))
|
||||
|
||||
if resp.status != 200:
|
||||
if resp.status == 401:
|
||||
raise LoginFailure('Improper token has been passed.')
|
||||
else:
|
||||
raise HTTPException(resp, None)
|
||||
|
||||
log.info('token auth returned status code {}'.format(resp.status))
|
||||
yield from self.http.static_login(token, bot=kwargs.pop('bot', True))
|
||||
self._is_logged_in.set()
|
||||
|
||||
@asyncio.coroutine
|
||||
def _login_2(self, email, password, **kwargs):
|
||||
# attempt to read the token from cache
|
||||
if self.cache_auth:
|
||||
yield from self._login_via_cache(email, password)
|
||||
if self.is_logged_in:
|
||||
token = self._get_cache_token()
|
||||
try:
|
||||
self.http.static_login(token, bot=False)
|
||||
except:
|
||||
log.info('cache auth token is out of date')
|
||||
else:
|
||||
self._is_logged_in.set()
|
||||
return
|
||||
|
||||
payload = {
|
||||
'email': email,
|
||||
'password': password
|
||||
}
|
||||
|
||||
data = utils.to_json(payload)
|
||||
resp = yield from self.session.post(endpoints.LOGIN, data=data, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=resp))
|
||||
if resp.status != 200:
|
||||
yield from resp.release()
|
||||
if resp.status == 400:
|
||||
raise LoginFailure('Improper credentials have been passed.')
|
||||
else:
|
||||
raise HTTPException(resp, None)
|
||||
|
||||
log.info('logging in returned status code {}'.format(resp.status))
|
||||
yield from self.http.email_login(email, password)
|
||||
self.email = email
|
||||
|
||||
body = yield from resp.json(encoding='utf-8')
|
||||
self.token = body['token']
|
||||
self.headers['authorization'] = self.token
|
||||
self._is_logged_in.set()
|
||||
|
||||
# since we went through all this trouble
|
||||
@ -395,12 +365,10 @@ class Client:
|
||||
def logout(self):
|
||||
"""|coro|
|
||||
|
||||
Logs out of Discord and closes all connections."""
|
||||
response = yield from self.session.post(endpoints.LOGOUT, headers=self.headers)
|
||||
yield from response.release()
|
||||
Logs out of Discord and closes all connections.
|
||||
"""
|
||||
yield from self.close()
|
||||
self._is_logged_in.clear()
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
|
||||
@asyncio.coroutine
|
||||
def connect(self):
|
||||
@ -453,7 +421,7 @@ class Client:
|
||||
yield from self.ws.close()
|
||||
|
||||
|
||||
yield from self.session.close()
|
||||
yield from self.http.close()
|
||||
self._closed.set()
|
||||
self._is_ready.clear()
|
||||
|
||||
@ -774,43 +742,11 @@ class Client:
|
||||
if not isinstance(user, User):
|
||||
raise InvalidArgument('user argument must be a User')
|
||||
|
||||
payload = {
|
||||
'recipient_id': user.id
|
||||
}
|
||||
|
||||
url = '{}/channels'.format(endpoints.ME)
|
||||
r = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=r, json=payload, data=data))
|
||||
data = yield from self.http.start_private_message(user.id)
|
||||
channel = PrivateChannel(id=data['id'], user=user)
|
||||
self.connection._add_private_channel(channel)
|
||||
return channel
|
||||
|
||||
@asyncio.coroutine
|
||||
def _retry_helper(self, name, *args, retries=0, **kwargs):
|
||||
req_kwargs = {'headers': self.headers}
|
||||
req_kwargs.update(kwargs)
|
||||
resp = yield from self.session.request(*args, **req_kwargs)
|
||||
tmp = request_logging_format.format(method=resp.method, response=resp)
|
||||
log_fmt = 'In {}, {}'.format(name, tmp)
|
||||
log.debug(log_fmt)
|
||||
|
||||
if resp.status == 502 and retries < 5:
|
||||
# retry the 502 request unconditionally
|
||||
log.info('Retrying the 502 request to ' + name)
|
||||
yield from asyncio.sleep(retries + 1)
|
||||
return (yield from self._retry_helper(name, *args, retries=retries + 1, **kwargs))
|
||||
|
||||
if resp.status == 429:
|
||||
retry = float(resp.headers['Retry-After']) / 1000.0
|
||||
yield from resp.release()
|
||||
yield from asyncio.sleep(retry)
|
||||
return (yield from self._retry_helper(name, *args, retries=retries, **kwargs))
|
||||
|
||||
return resp
|
||||
|
||||
@asyncio.coroutine
|
||||
def send_message(self, destination, content, *, tts=False):
|
||||
"""|coro|
|
||||
@ -858,23 +794,11 @@ class Client:
|
||||
The message that was sent.
|
||||
"""
|
||||
|
||||
channel_id = yield from self._resolve_destination(destination)
|
||||
channel_id, guild_id = yield from self._resolve_destination(destination)
|
||||
|
||||
content = str(content)
|
||||
|
||||
url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id)
|
||||
payload = {
|
||||
'content': content,
|
||||
'nonce': random_integer(-2**63, 2**63 - 1)
|
||||
}
|
||||
|
||||
if tts:
|
||||
payload['tts'] = True
|
||||
|
||||
resp = yield from self._retry_helper('send_message', 'POST', url, data=utils.to_json(payload))
|
||||
yield from utils._verify_successful_response(resp)
|
||||
data = yield from resp.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=resp, json=payload, data=data))
|
||||
data = yield from self.http.send_message(channel_id, content, guild_id=guild_id, tts=tts)
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message = Message(channel=channel, **data)
|
||||
return message
|
||||
@ -895,14 +819,8 @@ class Client:
|
||||
The location to send the typing update.
|
||||
"""
|
||||
|
||||
channel_id = yield from self._resolve_destination(destination)
|
||||
|
||||
url = '{base}/{id}/typing'.format(base=endpoints.CHANNELS, id=channel_id)
|
||||
|
||||
response = yield from self.session.post(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
channel_id, guild_id = yield from self._resolve_destination(destination)
|
||||
yield from self.http.send_typing(channel_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def send_file(self, destination, fp, *, filename=None, content=None, tts=False):
|
||||
@ -951,34 +869,18 @@ class Client:
|
||||
The message sent.
|
||||
"""
|
||||
|
||||
channel_id = yield from self._resolve_destination(destination)
|
||||
|
||||
url = '{base}/{id}/messages'.format(base=endpoints.CHANNELS, id=channel_id)
|
||||
form = aiohttp.FormData()
|
||||
|
||||
if content is not None:
|
||||
form.add_field('content', str(content))
|
||||
|
||||
form.add_field('tts', 'true' if tts else 'false')
|
||||
|
||||
# we don't want the content-type json in this request
|
||||
headers = self.headers.copy()
|
||||
headers.pop('content-type', None)
|
||||
channel_id, guild_id = yield from self._resolve_destination(destination)
|
||||
|
||||
try:
|
||||
# attempt to open the file and send the request
|
||||
with open(fp, 'rb') as f:
|
||||
form.add_field('file', f, filename=filename, content_type='application/octet-stream')
|
||||
response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers)
|
||||
buffer = f.read()
|
||||
if filename is None:
|
||||
filename = fp
|
||||
except TypeError:
|
||||
form.add_field('file', fp, filename=filename, content_type='application/octet-stream')
|
||||
response = yield from self._retry_helper("send_file", "POST", url, data=form, headers=headers)
|
||||
buffer = fp
|
||||
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
data = yield from response.json(encoding='utf-8')
|
||||
msg = 'POST {0.url} returned {0.status} with {1} response'
|
||||
log.debug(msg.format(response, data))
|
||||
data = yield from self.http.send_file(channel_id, buffer, guild_id=guild_id,
|
||||
filename=filename, content=content, tts=tts)
|
||||
channel = self.get_channel(data.get('channel_id'))
|
||||
message = Message(channel=channel, **data)
|
||||
return message
|
||||
@ -1004,12 +906,8 @@ class Client:
|
||||
HTTPException
|
||||
Deleting the message failed.
|
||||
"""
|
||||
|
||||
url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, message.channel.id, message.id)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
channel = message.channel
|
||||
yield from self.http.delete_message(channel.id, message.id, channel.server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete_messages(self, messages):
|
||||
@ -1045,16 +943,9 @@ class Client:
|
||||
if len(messages) > 100 or len(messages) < 2:
|
||||
raise ClientException('Can only delete messages in the range of [2, 100]')
|
||||
|
||||
channel_id = messages[0].channel.id
|
||||
url = '{0}/{1}/messages/bulk_delete'.format(endpoints.CHANNELS, channel_id)
|
||||
payload = {
|
||||
'messages': [m.id for m in messages]
|
||||
}
|
||||
|
||||
response = yield from self.session.post(url, headers=self.headers, data=utils.to_json(payload))
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
channel = messages[0].channel
|
||||
message_ids = [m.id for m in messages]
|
||||
yield from self.http.delete_messages(channel.id, message_ids, channel.server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def purge_from(self, channel, *, limit=100, check=None, before=None, after=None):
|
||||
@ -1179,19 +1070,9 @@ class Client:
|
||||
channel = message.channel
|
||||
content = str(new_content)
|
||||
|
||||
url = '{}/{}/messages/{}'.format(endpoints.CHANNELS, channel.id, message.id)
|
||||
payload = {
|
||||
'content': content
|
||||
}
|
||||
|
||||
response = yield from self._retry_helper('edit_message', 'PATCH', url, data=utils.to_json(payload))
|
||||
log.debug(request_logging_format.format(method='PATCH', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
data = yield from response.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=response, json=payload, data=data))
|
||||
data = yield from self.http.edit_message(message.id, channel.id, content, guild_id=channel.server.id)
|
||||
return Message(channel=channel, **data)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _logs_from(self, channel, limit=100, before=None, after=None):
|
||||
"""|coro|
|
||||
|
||||
@ -1242,21 +1123,7 @@ class Client:
|
||||
if message.author == client.user:
|
||||
counter += 1
|
||||
"""
|
||||
url = '{}/{}/messages'.format(endpoints.CHANNELS, channel.id)
|
||||
params = {
|
||||
'limit': limit
|
||||
}
|
||||
|
||||
if before:
|
||||
params['before'] = before.id
|
||||
if after:
|
||||
params['after'] = after.id
|
||||
|
||||
response = yield from self.session.get(url, params=params, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='GET', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
messages = yield from response.json(encoding='utf-8')
|
||||
return messages
|
||||
return self.http.logs_from(channel.id, limit, before=before, after=after)
|
||||
|
||||
if PY35:
|
||||
def logs_from(self, channel, limit=100, *, before=None, after=None, reverse=False):
|
||||
@ -1356,12 +1223,7 @@ class Client:
|
||||
HTTPException
|
||||
Kicking failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.kick(member.id, member.server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def ban(self, member, delete_message_days=1):
|
||||
@ -1390,16 +1252,7 @@ class Client:
|
||||
HTTPException
|
||||
Banning failed.
|
||||
"""
|
||||
|
||||
params = {
|
||||
'delete-message-days': delete_message_days
|
||||
}
|
||||
|
||||
url = '{0}/{1.server.id}/bans/{1.id}'.format(endpoints.SERVERS, member)
|
||||
response = yield from self.session.put(url, params=params, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PUT', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.ban(member.id, member.server.id, delete_message_days)
|
||||
|
||||
@asyncio.coroutine
|
||||
def unban(self, server, user):
|
||||
@ -1421,12 +1274,7 @@ class Client:
|
||||
HTTPException
|
||||
Unbanning failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/bans/{2.id}'.format(endpoints.SERVERS, server, user)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.unban(user.id, server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def server_voice_state(self, member, *, mute=False, deafen=False):
|
||||
@ -1456,17 +1304,7 @@ class Client:
|
||||
HTTPException
|
||||
The operation failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
|
||||
payload = {
|
||||
'mute': mute,
|
||||
'deaf': deafen
|
||||
}
|
||||
|
||||
response = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.server_voice_state(member.id, member.server.id, mute=mute, deafen=deafen)
|
||||
|
||||
@asyncio.coroutine
|
||||
def edit_profile(self, password=None, **fields):
|
||||
@ -1527,30 +1365,21 @@ class Client:
|
||||
if not_bot_account and password is None:
|
||||
raise ClientException('Password is required for non-bot accounts.')
|
||||
|
||||
payload = {
|
||||
args = {
|
||||
'password': password,
|
||||
'username': fields.get('username', self.user.name),
|
||||
'avatar': avatar
|
||||
}
|
||||
|
||||
if not_bot_account:
|
||||
payload['email'] = fields.get('email', self.email)
|
||||
args['email'] = fields.get('email', self.email)
|
||||
|
||||
if 'new_password' in fields:
|
||||
payload['new_password'] = fields['new_password']
|
||||
|
||||
|
||||
r = yield from self.session.patch(endpoints.ME, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=r, json=payload, data=data))
|
||||
args['new_password'] = fields['new_password']
|
||||
|
||||
yield from self.http.edit_profile(**args)
|
||||
if not_bot_account:
|
||||
self.token = data['token']
|
||||
self.email = data['email']
|
||||
self.headers['authorization'] = self.token
|
||||
|
||||
if self.cache_auth:
|
||||
self._update_cache(self.email, password)
|
||||
@ -1608,24 +1437,12 @@ class Client:
|
||||
Changing the nickname failed.
|
||||
"""
|
||||
|
||||
nickname = nickname if nickname else ''
|
||||
|
||||
if member == self.user:
|
||||
fmt = '{0}/{1.server.id}/members/@me/nick'
|
||||
yield from self.http.change_my_nickname(member.server.id, nickname)
|
||||
else:
|
||||
fmt = '{0}/{1.server.id}/members/{1.id}'
|
||||
|
||||
url = fmt.format(endpoints.SERVERS, member)
|
||||
|
||||
payload = {
|
||||
# oddly enough, this endpoint requires '' to clear the nickname
|
||||
# instead of the more consistent 'null', this might change in the
|
||||
# future, or not.
|
||||
'nick': nickname if nickname else ''
|
||||
}
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
yield from r.release()
|
||||
yield from self.http.change_nickname(member.server.id, member.id, nickname)
|
||||
|
||||
# Channel management
|
||||
|
||||
@ -1662,26 +1479,7 @@ class Client:
|
||||
Editing the channel failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}'.format(endpoints.CHANNELS, channel)
|
||||
payload = {
|
||||
'name': options.get('name', channel.name),
|
||||
'topic': options.get('topic', channel.topic),
|
||||
}
|
||||
|
||||
user_limit = options.get('user_limit')
|
||||
if user_limit is not None:
|
||||
payload['user_limit'] = user_limit
|
||||
|
||||
bitrate = options.get('bitrate')
|
||||
if bitrate is not None:
|
||||
payload['bitrate'] = bitrate
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=r, json=payload, data=data))
|
||||
yield from self.http.edit_channel(channel.id, **options)
|
||||
|
||||
@asyncio.coroutine
|
||||
def move_channel(self, channel, position):
|
||||
@ -1735,13 +1533,7 @@ class Client:
|
||||
channels.insert(position, channel)
|
||||
|
||||
payload = [{'id': c.id, 'position': index } for index, c in enumerate(channels)]
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
yield from r.release()
|
||||
log.debug(request_success_log.format(json=payload, response=r, data={}))
|
||||
yield from self.http.patch(url, json=payload, bucket='move_channel')
|
||||
|
||||
@asyncio.coroutine
|
||||
def create_channel(self, server, name, type=None):
|
||||
@ -1779,18 +1571,7 @@ class Client:
|
||||
if type is None:
|
||||
type = ChannelType.text
|
||||
|
||||
payload = {
|
||||
'name': name,
|
||||
'type': str(type)
|
||||
}
|
||||
|
||||
url = '{0}/{1.id}/channels'.format(endpoints.SERVERS, server)
|
||||
response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
|
||||
data = yield from response.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=response, data=data, json=payload))
|
||||
data = yield from self.http.create_channel(server.id, name, str(type))
|
||||
channel = Channel(server=server, **data)
|
||||
return channel
|
||||
|
||||
@ -1817,12 +1598,7 @@ class Client:
|
||||
HTTPException
|
||||
Deleting the channel failed.
|
||||
"""
|
||||
|
||||
url = '{}/{}'.format(endpoints.CHANNELS, channel.id)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.delete_channel(channel.id)
|
||||
|
||||
# Server management
|
||||
|
||||
@ -1847,12 +1623,7 @@ class Client:
|
||||
HTTPException
|
||||
If leaving the server failed.
|
||||
"""
|
||||
|
||||
url = '{}/@me/guilds/{.id}'.format(endpoints.USERS, server)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.leave_server(server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete_server(self, server):
|
||||
@ -1874,11 +1645,7 @@ class Client:
|
||||
You do not have permissions to delete the server.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.delete_server(server.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def create_server(self, name, region=None, icon=None):
|
||||
@ -1918,17 +1685,7 @@ class Client:
|
||||
else:
|
||||
region = region.name
|
||||
|
||||
payload = {
|
||||
'icon': icon,
|
||||
'name': name,
|
||||
'region': region
|
||||
}
|
||||
|
||||
r = yield from self.session.post(endpoints.SERVERS, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(response=r, json=payload, data=data))
|
||||
data = yield from self.http.create_server(name, region, icon)
|
||||
return Server(**data)
|
||||
|
||||
@asyncio.coroutine
|
||||
@ -1984,30 +1741,18 @@ class Client:
|
||||
else:
|
||||
icon = None
|
||||
|
||||
payload = {
|
||||
'region': str(fields.get('region', server.region)),
|
||||
'afk_timeout': fields.get('afk_timeout', server.afk_timeout),
|
||||
'icon': icon,
|
||||
'name': fields.get('name', server.name),
|
||||
}
|
||||
|
||||
afk_channel = fields.get('afk_channel')
|
||||
if afk_channel is None:
|
||||
afk_channel = server.afk_channel
|
||||
|
||||
payload['afk_channel'] = getattr(afk_channel, 'id', None)
|
||||
fields['icon'] = icon
|
||||
if 'afk_channel' in fields:
|
||||
fields['afk_channel_id'] = fields['afk_channel'].id
|
||||
|
||||
if 'owner' in fields:
|
||||
if server.owner != server.me:
|
||||
raise InvalidArgument('To transfer ownership you must be the owner of the server.')
|
||||
|
||||
payload['owner_id'] = fields['owner'].id
|
||||
fields['owner_id'] = fields['owner'].id
|
||||
|
||||
yield from self.http.edit_server(server.id, **fields)
|
||||
|
||||
url = '{0}/{1.id}'.format(endpoints.SERVERS, server)
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
yield from r.release()
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_bans(self, server):
|
||||
@ -2036,11 +1781,7 @@ class Client:
|
||||
A list of :class:`User` that have been banned.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/bans'.format(endpoints.SERVERS, server)
|
||||
resp = yield from self.session.get(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='GET', response=resp))
|
||||
yield from utils._verify_successful_response(resp)
|
||||
data = yield from resp.json(encoding='utf-8')
|
||||
data = yield from self.http.get_bans(server.id)
|
||||
return [User(**user['user']) for user in data]
|
||||
|
||||
# Invite management
|
||||
@ -2092,20 +1833,7 @@ class Client:
|
||||
The invite that was created.
|
||||
"""
|
||||
|
||||
payload = {
|
||||
'max_age': options.get('max_age', 0),
|
||||
'max_uses': options.get('max_uses', 0),
|
||||
'temporary': options.get('temporary', False),
|
||||
'xkcdpass': options.get('xkcd', False)
|
||||
}
|
||||
|
||||
url = '{0}/{1.id}/invites'.format(endpoints.CHANNELS, destination)
|
||||
response = yield from self.session.post(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
|
||||
yield from utils._verify_successful_response(response)
|
||||
data = yield from response.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(json=payload, response=response, data=data))
|
||||
data = yield from self.http.create_invite(destination.id, **options)
|
||||
self._fill_invite_data(data)
|
||||
return Invite(**data)
|
||||
|
||||
@ -2139,12 +1867,8 @@ class Client:
|
||||
The invite from the URL/ID.
|
||||
"""
|
||||
|
||||
destination = self._resolve_invite(url)
|
||||
rurl = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
|
||||
response = yield from self.session.get(rurl, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='GET', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
data = yield from response.json(encoding='utf-8')
|
||||
invite_id = self._resolve_invite(url)
|
||||
data = yield from self.http.get_invite(invite_id)
|
||||
self._fill_invite_data(data)
|
||||
return Invite(**data)
|
||||
|
||||
@ -2174,11 +1898,7 @@ class Client:
|
||||
The list of invites that are currently active.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/invites'.format(endpoints.SERVERS, server)
|
||||
resp = yield from self.session.get(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='GET', response=resp))
|
||||
yield from utils._verify_successful_response(resp)
|
||||
data = yield from resp.json(encoding='utf-8')
|
||||
data = yield from self.http.invites_from(server.id)
|
||||
result = []
|
||||
for invite in data:
|
||||
channel = server.get_channel(invite['channel']['id'])
|
||||
@ -2210,12 +1930,8 @@ class Client:
|
||||
The invite is invalid or expired.
|
||||
"""
|
||||
|
||||
destination = self._resolve_invite(invite)
|
||||
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
|
||||
response = yield from self.session.post(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
invite_id = self._resolve_invite(invite)
|
||||
yield from self.http.accept_invite(invite_id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete_invite(self, invite):
|
||||
@ -2241,12 +1957,8 @@ class Client:
|
||||
Revoking the invite failed.
|
||||
"""
|
||||
|
||||
destination = self._resolve_invite(invite)
|
||||
url = '{0}/invite/{1}'.format(endpoints.API_BASE, destination)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
invite_id = self._resolve_invite(invite)
|
||||
yield from self.http.delete_invite(invite_id)
|
||||
|
||||
# Role management
|
||||
|
||||
@ -2298,13 +2010,7 @@ class Client:
|
||||
roles.append(role.id)
|
||||
|
||||
payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
data = yield from r.json()
|
||||
log.debug(request_success_log.format(json=payload, response=r, data=data))
|
||||
yield from self.http.patch(url, json=payload, bucket='move_role')
|
||||
|
||||
@asyncio.coroutine
|
||||
def edit_role(self, server, role, **fields):
|
||||
@ -2345,11 +2051,6 @@ class Client:
|
||||
Editing the role failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role)
|
||||
color = fields.get('color')
|
||||
if color is None:
|
||||
color = fields.get('colour', role.colour)
|
||||
|
||||
payload = {
|
||||
'name': fields.get('name', role.name),
|
||||
'permissions': fields.get('permissions', role.permissions).value,
|
||||
@ -2358,12 +2059,7 @@ class Client:
|
||||
'mentionable': fields.get('mentionable', role.mentionable)
|
||||
}
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
log.debug(request_success_log.format(json=payload, response=r, data=data))
|
||||
yield from self.http.edit_role(server.id, role.id, **payload)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete_role(self, server, role):
|
||||
@ -2386,24 +2082,11 @@ class Client:
|
||||
Deleting the role failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/roles/{2.id}'.format(endpoints.SERVERS, server, role)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.delete_role(server.id, role.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def _replace_roles(self, member, roles):
|
||||
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
|
||||
|
||||
payload = {
|
||||
'roles': roles
|
||||
}
|
||||
|
||||
r = yield from self.session.patch(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
yield from r.release()
|
||||
yield from self.http.replace_roles(member.id, member.server.id, roles)
|
||||
|
||||
@asyncio.coroutine
|
||||
def add_roles(self, member, *roles):
|
||||
@ -2521,12 +2204,7 @@ class Client:
|
||||
is stored in cache.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/roles'.format(endpoints.SERVERS, server)
|
||||
r = yield from self.session.post(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='POST', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
|
||||
data = yield from r.json(encoding='utf-8')
|
||||
data = yield from self.http.create_role(server.id)
|
||||
role = Role(server=server, **data)
|
||||
|
||||
# we have to call edit because you can't pass a payload to the
|
||||
@ -2581,8 +2259,6 @@ class Client:
|
||||
or the target type was not :class:`Role` or :class:`Member`.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target)
|
||||
|
||||
allow = Permissions.none() if allow is None else allow
|
||||
deny = Permissions.none() if deny is None else deny
|
||||
|
||||
@ -2592,23 +2268,14 @@ class Client:
|
||||
deny = deny.value
|
||||
allow = allow.value
|
||||
|
||||
payload = {
|
||||
'id': target.id,
|
||||
'allow': allow,
|
||||
'deny': deny
|
||||
}
|
||||
|
||||
if isinstance(target, Member):
|
||||
payload['type'] = 'member'
|
||||
perm_type = 'member'
|
||||
elif isinstance(target, Role):
|
||||
payload['type'] = 'role'
|
||||
perm_type = 'role'
|
||||
else:
|
||||
raise InvalidArgument('target parameter must be either discord.Member or discord.Role')
|
||||
|
||||
r = yield from self.session.put(url, data=utils.to_json(payload), headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PUT', response=r))
|
||||
yield from utils._verify_successful_response(r)
|
||||
yield from r.release()
|
||||
yield from self.http.edit_channel_permissions(channel.id, target.id, allow, deny, perm_type)
|
||||
|
||||
@asyncio.coroutine
|
||||
def delete_channel_permissions(self, channel, target):
|
||||
@ -2637,12 +2304,7 @@ class Client:
|
||||
HTTPException
|
||||
Deleting channel specific permissions failed.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.id}/permissions/{2.id}'.format(endpoints.CHANNELS, channel, target)
|
||||
response = yield from self.session.delete(url, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='DELETE', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.delete_channel_permissions(channel.id, target.id)
|
||||
|
||||
# Voice management
|
||||
|
||||
@ -2676,18 +2338,10 @@ class Client:
|
||||
You do not have permissions to move the member.
|
||||
"""
|
||||
|
||||
url = '{0}/{1.server.id}/members/{1.id}'.format(endpoints.SERVERS, member)
|
||||
|
||||
if getattr(channel, 'type', ChannelType.text) != ChannelType.voice:
|
||||
raise InvalidArgument('The channel provided must be a voice channel.')
|
||||
|
||||
payload = utils.to_json({
|
||||
'channel_id': channel.id
|
||||
})
|
||||
response = yield from self.session.patch(url, data=payload, headers=self.headers)
|
||||
log.debug(request_logging_format.format(method='PATCH', response=response))
|
||||
yield from utils._verify_successful_response(response)
|
||||
yield from response.release()
|
||||
yield from self.http.move_member(member.id, member.server.id, channel.id)
|
||||
|
||||
@asyncio.coroutine
|
||||
def join_voice_channel(self, channel):
|
||||
@ -2817,10 +2471,7 @@ class Client:
|
||||
HTTPException
|
||||
Retrieving the information failed somehow.
|
||||
"""
|
||||
url = '{}/@me'.format(endpoints.APPLICATIONS)
|
||||
resp = yield from self.session.get(url, headers=self.headers)
|
||||
yield from utils._verify_successful_response(resp)
|
||||
data = yield from resp.json()
|
||||
data = yield from self.http.application_info()
|
||||
return AppInfo(id=data['id'], name=data['name'],
|
||||
description=data['description'], icon=data['icon'])
|
||||
|
||||
|
@ -40,7 +40,7 @@ import struct
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
|
||||
__all__ = [ 'ReconnectWebSocket', 'DiscordWebSocket',
|
||||
'KeepAliveHandler', 'VoiceKeepAliveHandler',
|
||||
'DiscordVoiceWebSocket', 'ResumeWebSocket' ]
|
||||
|
||||
@ -97,36 +97,6 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
'd': int(time.time() * 1000)
|
||||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_gateway(token, *, loop=None):
|
||||
"""Returns the gateway URL for connecting to the WebSocket.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
token : str
|
||||
The discord authentication token.
|
||||
loop
|
||||
The event loop.
|
||||
|
||||
Raises
|
||||
------
|
||||
GatewayNotFound
|
||||
When the gateway is not returned gracefully.
|
||||
"""
|
||||
headers = {
|
||||
'authorization': token,
|
||||
'content-type': 'application/json'
|
||||
}
|
||||
|
||||
with aiohttp.ClientSession(loop=loop) as session:
|
||||
resp = yield from session.get(endpoints.GATEWAY, headers=headers)
|
||||
if resp.status != 200:
|
||||
yield from resp.release()
|
||||
raise GatewayNotFound()
|
||||
data = yield from resp.json(encoding='utf-8')
|
||||
return data.get('url') + '?encoding=json&v=4'
|
||||
|
||||
class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
"""Implements a WebSocket for Discord's gateway v4.
|
||||
|
||||
@ -190,11 +160,11 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
|
||||
This is for internal use only.
|
||||
"""
|
||||
gateway = yield from get_gateway(client.token, loop=client.loop)
|
||||
gateway = yield from client.http.get_gateway()
|
||||
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||
|
||||
# dynamically add attributes needed
|
||||
ws.token = client.token
|
||||
ws.token = client.http.token
|
||||
ws._connection = client.connection
|
||||
ws._dispatch = client.dispatch
|
||||
ws.gateway = gateway
|
||||
@ -505,7 +475,7 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
'server_id': client.guild_id,
|
||||
'user_id': client.user.id,
|
||||
'session_id': client.session_id,
|
||||
'token': client.token
|
||||
'token': client.http.token
|
||||
}
|
||||
}
|
||||
|
||||
|
484
discord/http.py
Normal file
484
discord/http.py
Normal file
@ -0,0 +1,484 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-2016 Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
import io
|
||||
import inspect
|
||||
import weakref
|
||||
from random import randint as random_integer
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound
|
||||
from . import utils, __version__
|
||||
|
||||
@asyncio.coroutine
|
||||
def json_or_text(response):
|
||||
text = yield from response.text(encoding='utf-8')
|
||||
if response.headers['content-type'] == 'application/json':
|
||||
return json.loads(text)
|
||||
return text
|
||||
|
||||
def _func_():
|
||||
# emulate __func__ from C++
|
||||
return inspect.currentframe().f_back.f_code.co_name
|
||||
|
||||
class HTTPClient:
|
||||
"""Represents an HTTP client sending HTTP requests to the Discord API."""
|
||||
|
||||
BASE = 'https://discordapp.com'
|
||||
API_BASE = BASE + '/api'
|
||||
GATEWAY = API_BASE + '/gateway'
|
||||
USERS = API_BASE + '/users'
|
||||
ME = USERS + '/@me'
|
||||
REGISTER = API_BASE + '/auth/register'
|
||||
LOGIN = API_BASE + '/auth/login'
|
||||
LOGOUT = API_BASE + '/auth/logout'
|
||||
GUILDS = API_BASE + '/guilds'
|
||||
CHANNELS = API_BASE + '/channels'
|
||||
APPLICATIONS = API_BASE + '/oauth2/applications'
|
||||
|
||||
SUCCESS_LOG = '{method} {url} with {json} has received {text}'
|
||||
REQUEST_LOG = '{method} {url} has returned {status}'
|
||||
|
||||
def __init__(self, connector=None, *, loop=None):
|
||||
self.loop = asyncio.get_event_loop() if loop is None else loop
|
||||
self.connector = connector
|
||||
self.session = aiohttp.ClientSession(connector=connector, loop=self.loop)
|
||||
self._locks = weakref.WeakValueDictionary()
|
||||
self.token = None
|
||||
self.bot_token = False
|
||||
|
||||
user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
|
||||
self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
|
||||
|
||||
@asyncio.coroutine
|
||||
def request(self, method, url, *, bucket=None, **kwargs):
|
||||
lock = self._locks.get(bucket)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock(loop=self.loop)
|
||||
if bucket is not None:
|
||||
self._locks[bucket] = lock
|
||||
|
||||
# header creation
|
||||
headers = {
|
||||
'User-Agent': self.user_agent,
|
||||
}
|
||||
|
||||
if self.token is not None:
|
||||
headers['Authorization'] = 'Bot ' + self.token if self.bot_token else self.token
|
||||
|
||||
# some checking if it's a JSON request
|
||||
if 'json' in kwargs:
|
||||
headers['Content-Type'] = 'application/json'
|
||||
kwargs['data'] = utils.to_json(kwargs.pop('json'))
|
||||
|
||||
kwargs['headers'] = headers
|
||||
with (yield from lock):
|
||||
for tries in range(5):
|
||||
r = yield from self.session.request(method, url, **kwargs)
|
||||
log.debug(self.REQUEST_LOG.format(method=method, url=url, status=r.status))
|
||||
try:
|
||||
# even errors have text involved in them so this is safe to call
|
||||
data = yield from json_or_text(r)
|
||||
|
||||
# the request was successful so just return the text/json
|
||||
if 300 > r.status >= 200:
|
||||
log.debug(self.SUCCESS_LOG.format(method=method, url=url,
|
||||
json=kwargs.get('data'), text=data))
|
||||
return data
|
||||
|
||||
# we are being rate limited
|
||||
if r.status == 429:
|
||||
fmt = 'We are being rate limited. Retrying in {:.2} seconds. Handled under the bucket "{}"'
|
||||
|
||||
# sleep a bit
|
||||
retry_after = data['retry_after'] / 1000.0
|
||||
log.info(fmt.format(retry_after, bucket))
|
||||
yield from asyncio.sleep(retry_after)
|
||||
continue
|
||||
|
||||
# we've received a 502, unconditional retry
|
||||
if r.status == 502 and tries <= 5:
|
||||
yield from asyncio.sleep(1 + tries * 2)
|
||||
continue
|
||||
|
||||
# the usual error cases
|
||||
if r.status == 403:
|
||||
raise Forbidden(r, data)
|
||||
elif r.status == 404:
|
||||
raise NotFound(r, data)
|
||||
else:
|
||||
raise HTTPException(r, data)
|
||||
finally:
|
||||
# clean-up just in case
|
||||
yield from r.release()
|
||||
|
||||
def get(self, *args, **kwargs):
|
||||
return self.request('GET', *args, **kwargs)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.request('PUT', *args, **kwargs)
|
||||
|
||||
def patch(self, *args, **kwargs):
|
||||
return self.request('PATCH', *args, **kwargs)
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
return self.request('DELETE', *args, **kwargs)
|
||||
|
||||
def post(self, *args, **kwargs):
|
||||
return self.request('POST', *args, **kwargs)
|
||||
|
||||
# state management
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self):
|
||||
yield from self.session.close()
|
||||
|
||||
def recreate(self):
|
||||
self.session = aiohttp.ClientSession(self.connector, loop=self.loop)
|
||||
|
||||
def _token(self, token, *, bot=True):
|
||||
self.token = token
|
||||
self.bot_token = bot
|
||||
|
||||
# login management
|
||||
|
||||
@asyncio.coroutine
|
||||
def email_login(self, email, password):
|
||||
payload = {
|
||||
'email': email,
|
||||
'password': password
|
||||
}
|
||||
|
||||
try:
|
||||
data = yield from self.post(self.LOGIN, json=payload, bucket=_func_())
|
||||
except HTTPException as e:
|
||||
if e.response.status == 400:
|
||||
raise LoginFailure('Improper credentials have been passed.') from e
|
||||
raise
|
||||
|
||||
self._token(data['token'], bot=False)
|
||||
return data
|
||||
|
||||
@asyncio.coroutine
|
||||
def static_login(self, token, *, bot):
|
||||
old_state = (self.token, self.bot_token)
|
||||
self._token(token, bot=bot)
|
||||
|
||||
try:
|
||||
data = yield from self.get(self.ME)
|
||||
except HTTPException as e:
|
||||
self._token(*old_state)
|
||||
if e.response.status == 401:
|
||||
raise LoginFailure('Improper token has been passed.') from e
|
||||
raise e
|
||||
|
||||
return data
|
||||
|
||||
def logout(self):
|
||||
return self.post(self.LOGOUT, bucket=_func_())
|
||||
|
||||
# Message management
|
||||
|
||||
def start_private_message(self, user_id):
|
||||
payload = {
|
||||
'recipient_id': user_id
|
||||
}
|
||||
|
||||
return self.post(self.ME + '/channels', json=payload, bucket=_func_())
|
||||
|
||||
def send_message(self, channel_id, content, *, guild_id=None, tts=False):
|
||||
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
|
||||
payload = {
|
||||
'content': str(content),
|
||||
'nonce': random_integer(-2**63, 2**63 - 1)
|
||||
}
|
||||
|
||||
if tts:
|
||||
payload['tts'] = True
|
||||
|
||||
return self.post(url, json=payload, bucket='messages:' + str(guild_id))
|
||||
|
||||
def send_typing(self, channel_id):
|
||||
url = '{0.CHANNELS}/{1}/typing'.format(self, channel_id)
|
||||
return self.post(url, bucket=_func_())
|
||||
|
||||
def send_file(self, channel_id, buffer, *, guild_id=None, filename=None, content=None, tts=False):
|
||||
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
|
||||
form = aiohttp.FormData()
|
||||
|
||||
if content is not None:
|
||||
form.add_field('content', str(content))
|
||||
|
||||
form.add_field('tts', 'true' if tts else 'false')
|
||||
form.add_field('file', io.BytesIO(buffer), filename=filename, content_type='application/octet-stream')
|
||||
|
||||
return self.post(url, data=form, bucket='messages:' + str(guild_id))
|
||||
|
||||
def delete_message(self, channel_id, message_id, guild_id=None):
|
||||
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
|
||||
bucket = '{}:{}'.format(_func_(), guild_id)
|
||||
return self.delete(url, bucket=bucket)
|
||||
|
||||
def delete_messages(self, channel_id, message_ids, guild_id=None):
|
||||
url = '{0.CHANNELS}/{1}/messages/bulk_delete'.format(self, channel_id)
|
||||
payload = {
|
||||
'messages': message_ids
|
||||
}
|
||||
bucket = '{}:{}'.format(_func_(), guild_id)
|
||||
return self.post(url, json=payload, bucket=bucket)
|
||||
|
||||
def edit_message(self, message_id, channel_id, content, *, guild_id=None):
|
||||
url = '{0.CHANNELS}/{1}/messages/{2}'.format(self, channel_id, message_id)
|
||||
payload = {
|
||||
'content': str(content)
|
||||
}
|
||||
return self.patch(url, json=payload, bucket='messages:' + str(guild_id))
|
||||
|
||||
|
||||
def logs_from(self, channel_id, limit, before=None, after=None):
|
||||
url = '{0.CHANNELS}/{1}/messages'.format(self, channel_id)
|
||||
params = {
|
||||
'limit': limit
|
||||
}
|
||||
|
||||
if before:
|
||||
params['before'] = before
|
||||
if after:
|
||||
params['after'] = after
|
||||
|
||||
return self.get(url, params=params, bucket=_func_())
|
||||
|
||||
# Member management
|
||||
|
||||
def kick(self, user_id, guild_id):
|
||||
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
def ban(self, user_id, guild_id, delete_message_days=1):
|
||||
url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
|
||||
params = {
|
||||
'delete-message-days': delete_message_days
|
||||
}
|
||||
return self.put(url, params=params, bucket=_func_())
|
||||
|
||||
def unban(self, user_id, guild_id):
|
||||
url = '{0.GUILDS}/{1}/bans/{2}'.format(self, guild_id, user_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
def server_voice_state(self, user_id, guild_id, *, mute=False, deafen=False):
|
||||
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
|
||||
payload = {
|
||||
'mute': mute,
|
||||
'deafen': deafen
|
||||
}
|
||||
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
|
||||
|
||||
def edit_profile(self, password, username, avatar, **fields):
|
||||
payload = {
|
||||
'password': password,
|
||||
'username': username,
|
||||
'avatar': avatar
|
||||
}
|
||||
|
||||
if 'email' in fields:
|
||||
payload['email'] = fields['email']
|
||||
|
||||
if 'new_password' in fields:
|
||||
payload['new_password'] = fields['new_password']
|
||||
|
||||
return self.patch(self.ME, json=payload, bucket=_func_())
|
||||
|
||||
def change_my_nickname(self, guild_id, nickname):
|
||||
url = '{0.GUILDS}/{1}/members/@me/nick'.format(self, guild_id)
|
||||
payload = {
|
||||
'nick': nickname
|
||||
}
|
||||
bucket = '{}:{}'.format(_func_(), guild_id)
|
||||
return self.patch(url, json=payload, bucket=bucket)
|
||||
|
||||
def change_nickname(self, guild_id, user_id, nickname):
|
||||
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
|
||||
payload = {
|
||||
'nick': nickname
|
||||
}
|
||||
bucket = '{}:{}'.format(_func_(), guild_id)
|
||||
return self.patch(url, json=payload, bucket=bucket)
|
||||
|
||||
# Channel management
|
||||
|
||||
def edit_channel(self, channel_id, **options):
|
||||
url = '{0.CHANNELS}/{1}'.format(self, channel_id)
|
||||
|
||||
valid_keys = ('name', 'topic', 'bitrate', 'user_limit')
|
||||
payload = {
|
||||
k: v for k, v in options.items() if k in valid_keys
|
||||
}
|
||||
|
||||
return self.patch(url, json=payload, bucket=_func_())
|
||||
|
||||
def create_channel(self, guild_id, name, channe_type):
|
||||
url = '{0.GUILDS}/{1}/channels'.format(self, guild_id)
|
||||
payload = {
|
||||
'name': name,
|
||||
'type': channe_type
|
||||
}
|
||||
|
||||
return self.post(url, json=payload, bucket=_func_())
|
||||
|
||||
def delete_channel(self, channel_id):
|
||||
url = '{0.CHANNELS}/{1}'.format(self, channel_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
# Server management
|
||||
|
||||
def leave_server(self, guild_id):
|
||||
url = '{0.USERS}/@me/guilds/{1}'.format(self, guild_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
def delete_server(self, guild_id):
|
||||
url = '{0.GUILDS}/{1}'.format(self, guild_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
def create_server(self, name, region, icon):
|
||||
payload = {
|
||||
'name': name,
|
||||
'icon': icon,
|
||||
'region': region
|
||||
}
|
||||
|
||||
return self.post(self.GUILDS, json=payload, bucket=_func_())
|
||||
|
||||
def edit_server(self, guild_id, **fields):
|
||||
valid_keys = ('name', 'region', 'icon', 'afk_timeout', 'owner_id',
|
||||
'afk_channel_id', 'splash', 'verification_level')
|
||||
|
||||
payload = {
|
||||
k: v for k, v in fields.items() if k in valid_keys
|
||||
}
|
||||
|
||||
url = '{0.GUILDS}/{1}'.format(self, guild_id)
|
||||
return self.patch(url, json=payload, bucket=_func_())
|
||||
|
||||
def get_bans(self, guild_id):
|
||||
url = '{0.GUILDS}/{1}/bans'.format(self, guild_id)
|
||||
return self.get(url, bucket=_func_())
|
||||
|
||||
# Invite management
|
||||
|
||||
def create_invite(self, channel_id, **options):
|
||||
url = '{0.CHANNELS}/{1}/invites'.format(self, channel_id)
|
||||
payload = {
|
||||
'max_age': options.get('max_age', 0),
|
||||
'max_uses': options.get('max_uses', 0),
|
||||
'temporary': options.get('temporary', False),
|
||||
'xkcdpass': options.get('xkcd', False)
|
||||
}
|
||||
|
||||
return self.post(url, json=payload, bucket=_func_())
|
||||
|
||||
def get_invite(self, invite_id):
|
||||
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
|
||||
return self.get(url, bucket=_func_())
|
||||
|
||||
def invites_from(self, guild_id):
|
||||
url = '{0.GUILDS}/{1}/invites'.format(self, guild_id)
|
||||
return self.get(url, bucket=_func_())
|
||||
|
||||
def accept_invite(self, invite_id):
|
||||
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
|
||||
return self.post(url, bucket=_func_())
|
||||
|
||||
def delete_invite(self, invite_id):
|
||||
url = '{0.API_BASE}/invite/{1}'.format(self, invite_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
# Role management
|
||||
|
||||
def edit_role(self, guild_id, role_id, **fields):
|
||||
url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
|
||||
valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable')
|
||||
payload = {
|
||||
k: v for k, v in fields.items() if k in valid_keys
|
||||
}
|
||||
return self.patch(url, json=payload, bucket='roles:' + str(guild_id))
|
||||
|
||||
def delete_role(self, guild_id, role_id):
|
||||
url = '{0.GUILDS}/{1}/roles/{2}'.format(self, guild_id, role_id)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
def replace_roles(self, user_id, guild_id, role_ids):
|
||||
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
|
||||
payload = {
|
||||
'roles': role_ids
|
||||
}
|
||||
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
|
||||
|
||||
def create_role(self, guild_id):
|
||||
url = '{0.GUILDS}/{1}/roles'.format(self, guild_id)
|
||||
return self.post(url, bucket=_func_())
|
||||
|
||||
def edit_channel_permissions(self, channel_id, target, allow, deny, type):
|
||||
url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
|
||||
payload = {
|
||||
'id': target,
|
||||
'allow': allow,
|
||||
'deny': deny,
|
||||
'type': type
|
||||
}
|
||||
return self.put(url, json=payload, bucket=_func_())
|
||||
|
||||
def delete_channel_permissions(self, channel_id, target):
|
||||
url = '{0.CHANNELS}/{1}/permissions/{2}'.format(self, channel_id, target)
|
||||
return self.delete(url, bucket=_func_())
|
||||
|
||||
# Voice management
|
||||
|
||||
def move_member(self, user_id, guild_id, channel_id):
|
||||
url = '{0.GUILDS}/{1}/members/{2}'.format(self, guild_id, user_id)
|
||||
payload = {
|
||||
'channel_id': channel_id
|
||||
}
|
||||
return self.patch(url, json=payload, bucket='members:' + str(guild_id))
|
||||
|
||||
# Misc
|
||||
|
||||
def application_info(self):
|
||||
url = '{0.APPLICATIONS}/@me'.format(self)
|
||||
return self.get(url, bucket=_func_())
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_gateway(self):
|
||||
try:
|
||||
data = yield from self.get(self.GATEWAY, bucket=_func_())
|
||||
except HTTPException as e:
|
||||
raise GatewayNotFound() from e
|
||||
return data.get('url') + '?encoding=json&v=4'
|
Loading…
x
Reference in New Issue
Block a user