280 Commits

Author SHA1 Message Date
Daishiky
4291d4cf31 Update discord/colour.py
Co-authored-by: proguy914629 <74696067+proguy914629bot@users.noreply.github.com>
2021-09-20 19:06:53 +02:00
Daishiky
28caab0974 Update discord/colour.py
Co-authored-by: proguy914629 <74696067+proguy914629bot@users.noreply.github.com>
2021-09-20 19:06:47 +02:00
Daishiky
f9e0e2b55a Update colour.py 2021-09-19 18:54:10 +02:00
Daishiky
d2aae4752a added embed_background color 2021-09-19 18:51:02 +02:00
Tom
9356e385d8 Merge pull request #2 from WhoTheOOF/patch-2
fix my dumb <>
2021-08-27 19:14:51 -07:00
Jadon
6feab9abba fix my dumb <> 2021-08-27 21:11:40 -05:00
Tom
3bbb1187c9 Merge pull request #1 from WhoTheOOF/patch-1
update links to new links
2021-08-27 19:09:24 -07:00
Jadon
cb3869f39c update links to new links 2021-08-27 21:07:45 -05:00
iDutchy
2b66f2288c well, that was another failed idea... 2021-07-13 20:06:54 -05:00
iDutchy
00619fc6cd Making progress, almost there... 2021-07-13 19:58:09 -05:00
iDutchy
3056c6f0f4 new attempr 2021-07-13 19:53:19 -05:00
iDutchy
d7a0b0af04 oops 2021-07-13 19:37:14 -05:00
iDutchy
f8bf64c9b7 that failed... 2021-07-13 19:32:30 -05:00
iDutchy
8e6601c2c5 first attempt to shortcuts 2021-07-13 19:28:52 -05:00
iDutchy
2df88ccc27 Apparently a conflict sneaked through 2021-06-25 21:07:02 -05:00
iDutchy
0f3b15f10d Merge branch 'Rapptz-v1.x' 2021-06-25 20:34:02 -05:00
iDutchy
2a32e56873 fix conflict 2021-06-25 20:33:22 -05:00
iDutchy
34f6c5db10 looks like I changed stuff? 2021-06-25 20:29:27 -05:00
Rapptz
462ba84809 Version bump to v1.7.3 2021-06-12 12:09:54 -04:00
Rapptz
9ff979572a Add changelog for v1.7.3 2021-06-12 12:09:38 -04:00
Rapptz
9376fcd69d Fix crash involving stickers 2021-06-12 12:04:25 -04:00
Rapptz
e79321c032 Fix DM channel permissions not having read_messages 2021-06-12 12:01:53 -04:00
Nadir Chowdhury
81457330ed [docs] typo fix 2021-06-12 12:01:14 -04:00
Alex Nørgaard
72dd2381b0 Update docs for (Partial)Message.publish to reflect the actual permissions needed 2021-06-12 12:01:01 -04:00
ChasL
f0c3568ea9 Fix for doc reference to python "raise" statement
:ref:`py:raise` -> :ref:`raise statement <py:raise>`

Before fix the text reads: "...define an on_error handler consisting
of a single empty The raise statement." After fix it should read: 
"...define an on_error handler consisting of a single empty raise
statement."
2021-06-12 12:00:33 -04:00
Rapptz
6a52eea6ff Fix AuditLogEntry.target being incorrect for bulk message delete
Fixes #6851
2021-05-05 00:18:44 -04:00
Rapptz
9d2576e464 Version bump to v1.7.2 2021-05-02 23:35:53 -04:00
Rapptz
69fdf8c67d Add changelog for v1.7.2 2021-05-02 23:35:25 -04:00
iDutchy
917649b1b2 fix conflicts 2021-05-02 21:05:05 -05:00
David
f6c8bfdf0d Clarify ClientUser.verified docs 2021-05-01 10:33:01 -04:00
MrKomodoDragon
5e7357efa5 Fix grammar in the Guild.edit docstring 2021-05-01 10:31:44 -04:00
pikaninja
318166d875 [docs] Update notes for get_user and get_member 2021-05-01 10:31:27 -04:00
HyperGH
648b786bc1 Adjust quickstart to not show commands example 2021-05-01 10:31:00 -04:00
jack1142
ccf7e65c50 Fix documentation for RoleConverter 2021-05-01 10:30:27 -04:00
Kino
eb1d03f8f7 Fix typo within HelpCommand.verify_checks documentation 2021-05-01 10:30:02 -04:00
Nadir Chowdhury
5b98ce1235 [docs] fix docstring of AppInfo 2021-05-01 10:29:31 -04:00
Cryptex
06184dc25f Update lavalink's repo url 2021-05-01 10:28:28 -04:00
Steve C
a9dba2753f [tasks] Move the Loop's sleep to be before exit conditions
This change makes it more so that `Loop.stop()` gracefully makes the
current iteration the final one, by waiting AND THEN returning.
The current implementation is closer to `cancel`, while also not.

I encountered this because I was trying to run a
`@tasks.loop(count=1)`, and inside it I print some text and change the
interval, and in an `after_loop`, I restart the loop.

Without this change, it immediately floods my console, due to
not waiting before executing `after_loop`.
2021-05-01 10:28:00 -04:00
NoName
87dd046c32 Add periods to sticker docs 2021-05-01 10:27:12 -04:00
Nadir Chowdhury
7d37c3a506 [docs] Fix various unresolved references 2021-05-01 10:26:11 -04:00
Steve C
4d47436b02 Fix guild.chunk() not working on evicted guilds
If you're trying to chunk a guild that the bot is not in, 
it'll just hang on the chunk coro forever. It's weird, I know.
2021-05-01 10:24:40 -04:00
numbermaniac
f50877c9b8 Add note to member docs about Spotify limitation 2021-05-01 10:24:07 -04:00
Maya
afca943f16 Fix exception for invalid channel types 2021-05-01 10:22:49 -04:00
Dan Hess
81e3f58f43 Fix async adapter detection in message deletion 2021-04-29 19:21:33 -04:00
iDutchy
523e35e4f3 Merge pull request #17 from TheMoksej/patch-2
add silent feature back to delete() function
2021-04-17 16:57:46 +02:00
Moksej
93da1d920e didn't mean to remove this 2021-04-17 16:47:15 +02:00
Moksej
cdfd918604 add silent feature back to delete() function
latest commit removed the silent kwarg from the delete function so I'm adding it back
2021-04-17 16:46:06 +02:00
Rapptz
f1130e4985 Fix fail_if_not_exists not being set when constructed with state 2021-04-15 09:00:51 -04:00
pikaninja
187b811836 Add StageChannel to abc.GuildChannel docs 2021-04-15 09:00:09 -04:00
Kino
ff04cab119 [docs] Fix reference to Guild.id 2021-04-15 09:00:09 -04:00
Rapptz
65205a8e39 Fix Intents resolution in the docs 2021-04-15 09:00:09 -04:00
Rapptz
28934001b2 Bring back discord module in discord.ext.commands documentation 2021-04-15 09:00:09 -04:00
Rapptz
af7d93725d Remove current module reference in commands API docs 2021-04-15 09:00:09 -04:00
jack1142
efa6482ac4 Add missing attribute tables 2021-04-15 08:58:32 -04:00
Rapptz
4756485ea4 [commands] Add StageChannelConverter to documentation 2021-04-15 08:58:00 -04:00
Shun Tannai
8b4d7f122c [commands] Update Converter list in ext.commands introduction 2021-04-15 08:57:45 -04:00
Sebastian Law
0d1cf57f62 [docs] add note for possible Embed.type strings 2021-04-15 08:57:25 -04:00
Sebastian Law
3ad795ab6a [docs] add label to basic converters section 2021-04-15 08:57:08 -04:00
iDutchy
a453266cd4 conflict fixes 2021-04-07 18:27:34 -05:00
Rapptz
8517f1e085 Version bump to v1.7.1 2021-04-05 21:21:41 -04:00
Rapptz
0a4be8f83c Update changelog for v1.7.1 2021-04-05 21:21:41 -04:00
Logan
56faa98e4b Fix versionadded not showing in docs for Attachment.content_type 2021-04-05 21:10:55 -04:00
Rapptz
1b2c527fc7 [commands] Fix logic in Cog.has_error_handler() 2021-04-05 21:08:54 -04:00
iDutchy
f1309aa4a1 added Client.get_message 2021-02-11 18:38:33 -06:00
iDutchy
86fd3fb738 conflict fixes 2021-01-14 18:03:09 -06:00
iDutchy
694e5e2861 guess I missed this one 2021-01-12 22:33:47 -06:00
iDutchy
54b5e253c9 another small fix 2021-01-12 22:30:55 -06:00
iDutchy
493fc88d6d fix some docs issues 2021-01-12 22:16:09 -06:00
iDutchy
732cab7c5c Works, so added to dos. + little cleanup 2021-01-12 21:57:23 -06:00
iDutchy
2c650feb98 cant support aiohttp 3.8 yet 2021-01-12 21:41:51 -06:00
iDutchy
316e74a9ec test for silent message deleting 2021-01-12 21:39:41 -06:00
iDutchy
89d2f00911 so eh, lets comment this out for now since using it crashes the machine... 2021-01-05 19:46:05 -06:00
iDutchy
a0d491c71f fix machine crashes? 2021-01-05 19:15:24 -06:00
iDutchy
be0933c928 broke case insensitive prefixes... Oops 2021-01-05 18:57:12 -06:00
iDutchy
db9dd93ad4 docs update + add case_insensitive_prefix 2021-01-05 18:20:23 -06:00
iDutchy
b50b8e903f update docs 2020-12-30 22:19:26 -06:00
iDutchy
57e6c946c9 Merge branch 'master' of https://github.com/iDutchy/discord.py 2020-12-30 18:52:30 -06:00
iDutchy
f4bec507c1 attempt at cog aliases 2020-12-30 18:52:15 -06:00
iDutchy
e090eb66d1 Merge pull request #12 from averwhy/master
fix typo in custom_features
2020-12-05 02:07:45 +01:00
iDutchy
0354036451 fix conflicts 2020-12-04 19:05:58 -06:00
iDutchy
9f54345f5c idk, first attempt at adding docs ig. Probably gonna break so yea... 2020-12-01 02:25:51 -06:00
averwhy
4f3d489135 fix typo in custom_features 2020-11-27 20:22:19 -05:00
iDutchy
6e024871ec fix merge conflict 2020-11-24 17:19:31 -06:00
iDutchy
777c95aab2 update 2020-11-24 17:17:25 -06:00
iDutchy
b058b4730c update changelog 2020-11-19 18:48:07 -06:00
iDutchy
725f08e45d add Color.random 2020-11-19 18:46:42 -06:00
iDutchy
b61b5b7414 type fix 2020-11-18 18:09:18 -06:00
iDutchy
e47ff96c30 docs fix 2020-11-18 18:04:33 -06:00
iDutchy
a2b513bd72 docs fix 2020-11-18 18:02:05 -06:00
iDutchy
9e6461a419 docs fix 2020-11-18 17:57:32 -06:00
iDutchy
690dcdaf2e update docs and add message replies 2020-11-18 17:49:58 -06:00
iDutchy
195bace135 admin alias 2020-11-18 16:41:38 -06:00
iDutchy
ae0f11ce53 add a color 2020-11-04 19:51:07 -06:00
iDutchy
188b69c097 docs update 2020-10-28 21:04:30 -05:00
iDutchy
dea09cb5b3 conflict fix 2020-10-28 21:00:48 -05:00
iDutchy
c223d2e723 better prefix 2020-10-28 20:49:49 -05:00
iDutchy
12de975b69 better prefix 2020-10-28 19:49:13 -05:00
iDutchy
14d8310192 properly checking TextChannel.can_send 2020-10-21 20:06:47 -05:00
iDutchy
c6b417bc7b version bump 2020-10-21 17:54:07 -05:00
iDutchy
3521ae985a added versionadded 2020-10-21 17:44:46 -05:00
iDutchy
2b5490d4cb fixes 2020-10-17 20:22:16 -05:00
iDutchy
18f80a737f hmm 2020-10-17 20:09:06 -05:00
iDutchy
faa566040c final step 2020-10-17 20:04:37 -05:00
iDutchy
2d7b6e239b preparing 2020-10-17 20:03:53 -05:00
iDutchy
09168d880f int() support 2020-10-17 19:14:51 -05:00
iDutchy
24839be99d test 2020-10-17 19:07:04 -05:00
iDutchy
4010f09052 Remove int() support 2020-10-17 18:45:59 -05:00
iDutchy
b9642f785e Remove int() support 2020-10-17 18:45:51 -05:00
iDutchy
d75cd66b90 oh ffs 2020-10-17 18:42:19 -05:00
iDutchy
a09e096d42 add menus 2020-10-08 16:06:18 -05:00
iDutchy
0b8671e3d6 doc update 2020-10-04 02:05:24 +00:00
iDutchy
9e0303cc53 Extra security for not overriding bot.embed_color 2020-10-04 02:03:55 +00:00
iDutchy
e3bce1ba58 add doc 2020-10-04 01:58:06 +00:00
iDutchy
ecd898e62c oop 2020-10-04 01:53:57 +00:00
iDutchy
f96a537b8f oop 2020-10-04 01:11:49 +00:00
iDutchy
36bcbb19ee oop 2020-10-04 01:01:17 +00:00
iDutchy
9fc2fd38dc Another request from shivaco 2020-10-04 00:39:03 +00:00
iDutchy
ea73008ff2 update docs 2020-10-02 00:34:53 +00:00
iDutchy
525ee4be0a update docs 2020-10-02 00:34:01 +00:00
iDutchy
7a56f0b28a requested by shivaco ;) 2020-10-02 00:23:25 +00:00
iDutchy
60d383cb51 oop 2020-10-01 22:45:57 +00:00
iDutchy
4297eed591 hmm 2020-10-01 22:41:02 +00:00
iDutchy
0ed9d8ca6b oop 2020-10-01 22:27:07 +00:00
iDutchy
b5ef2bdec4 oop 2020-10-01 22:24:15 +00:00
iDutchy
768f409a84 hmm 2020-10-01 22:22:53 +00:00
iDutchy
7afaa6dfce hmm 2020-10-01 22:16:20 +00:00
iDutchy
21ea6fe9ac hmm 2020-10-01 22:16:05 +00:00
iDutchy
291237bac9 hmm 2020-10-01 22:11:25 +00:00
iDutchy
397535f1e5 hmm 2020-10-01 22:05:34 +00:00
iDutchy
b13eca9def hmm 2020-10-01 22:01:20 +00:00
iDutchy
e6cf6c4b8c hmm 2020-10-01 21:58:29 +00:00
iDutchy
e91a0d62f7 hmm 2020-10-01 21:55:05 +00:00
iDutchy
0860df8fa5 hmm 2020-10-01 21:53:49 +00:00
iDutchy
1c228f9548 oop 2020-10-01 21:43:22 +00:00
iDutchy
0088ab589b oop 2020-10-01 21:42:00 +00:00
iDutchy
d79bc7c3c9 support for default embed color 2020-10-01 21:37:37 +00:00
iDutchy
5b2c7db90b forgot 2020-10-01 02:27:11 +00:00
iDutchy
cd6f48b39c forgot this 2020-10-01 01:28:20 +00:00
iDutchy
74713b05ee add try_user 2020-10-01 01:24:30 +00:00
iDutchy
7cb96f7ba4 oop again 2020-10-01 01:02:28 +00:00
iDutchy
fe826b7134 changes 2020-10-01 01:01:18 +00:00
iDutchy
db6a7d46a1 oop 2020-10-01 00:34:59 +00:00
iDutchy
fd9ceb30f2 I suck 2020-09-30 23:59:06 +00:00
iDutchy
f514d45f99 oops 2020-09-30 23:56:50 +00:00
iDutchy
57efed682b update docs 2020-09-30 23:54:25 +00:00
iDutchy
f33cfbce0d docs update 2020-09-30 23:49:33 +00:00
iDutchy
e21fb1217e Merge pull request #8 from iDutchy/owner
add Bot.owners
2020-10-01 01:28:26 +02:00
iDutchy
73ed64c527 add Bot.owners 2020-10-01 01:24:25 +02:00
iDutchy
7356a641b5 Merge pull request #7 from iDutchy/owner
add Bot.owner
2020-10-01 01:22:12 +02:00
iDutchy
447a6a694e add Bot.owner 2020-10-01 01:20:34 +02:00
iDutchy
7af9f2af94 Merge branch 'Rapptz-neo-docs' 2020-09-30 22:42:38 +00:00
iDutchy
f80b4c166c conflict fix 2020-09-30 22:42:28 +00:00
iDutchy
b98fc6f2f6 Merge pull request #5 from Rapptz/master
1.5 release
2020-10-01 00:18:26 +02:00
iDutchy
c8bd6884dc Merge branch 'Rapptz-master' 2020-09-28 00:42:59 +00:00
iDutchy
64be57b192 fixes 2020-09-28 00:42:27 +00:00
Josh
3cc5e23392 Set maximimum sidebar width 2020-09-23 02:28:17 -04:00
Muhammad Hamza
ba7482921b [matrix] Style "View Documentation For" dropdown 2020-09-22 20:29:04 -04:00
iDutchy
2774cfd3e9 welp, wasnt aware __hex__ got removed in py3 2020-09-18 01:04:04 +00:00
iDutchy
bbaf3375a8 a fix I think? 2020-09-18 00:54:26 +00:00
iDutchy
d28f0ff35b add hex() support to Color 2020-09-18 00:44:40 +00:00
iDutchy
b2540ee312 Lets add some color! 2020-09-18 00:34:29 +00:00
iDutchy
a67bb723b4 competing type added 2020-09-15 20:34:43 +00:00
iDutchy
5756548a6a Merge branch 'Rapptz-feature/intents' 2020-09-15 00:37:04 +00:00
iDutchy
571ddb5a3e merge conflict fix 2020-09-15 00:36:19 +00:00
Rapptz
6546f63ad7 Add a special exception for required privileged intents 2020-09-14 03:49:21 -04:00
Rapptz
4c56e6da9c Pass default intents if not explicitly given 2020-09-14 03:20:41 -04:00
Rapptz
27b224778b Intern status and overwrite strings 2020-09-14 03:20:41 -04:00
Rapptz
ab049e3eb0 Allow finer grained control over the member cache. 2020-09-14 03:20:36 -04:00
Rapptz
6f22ba8ad0 Raise if member intent is not enabled 2020-09-14 03:20:17 -04:00
Rapptz
f3514a4d53 Don't cache members during guild start up if cache is disabled.
This is mainly a half-implemented commit. There are a few more places
where cache consistency is necessary. In the future there will
probably be a member cache policy enum that will be used and cache
consistency will be tackled in part of that larger refactoring.
2020-09-14 03:20:17 -04:00
Rapptz
141511471e Add Guild.chunk and deprecated Client.request_offline_members 2020-09-14 03:20:16 -04:00
Rapptz
27558ec71a Fix Client.request_offline_members no longer working 2020-09-14 03:20:16 -04:00
Rapptz
a6edb66742 Add versionadded for intents enum 2020-09-14 03:20:16 -04:00
Rapptz
b1de57f299 Explicitly disable the members presence by default 2020-09-14 03:20:16 -04:00
Rapptz
0fc8ac6f80 Fix timeouts due to hitting the gateway rate limit 2020-09-14 03:20:16 -04:00
Rapptz
18141c0cf9 Maximize the amount of concurrency while chunking.
In order to reduce our amount of backpressure we need to limit the
amount of concurrent chunk requests we can have so the gateway buffer
has some time to breathe.
2020-09-14 03:20:16 -04:00
Rapptz
022ec9af1d Check for zombie connections through last received payload
The previous code would check zombie connections depending on whether
HEARTBEAT_ACK was received. Unfortunately when there's exceeding
backpressure the connection can terminate since the HEARTBEAT_ACK is
buffered very far away despite it being there, just not received yet.
2020-09-14 03:20:16 -04:00
Rapptz
9492cb1242 Speed up chunking for guilds with presence intent enabled 2020-09-14 03:20:16 -04:00
Rapptz
a76f9ce8ef Maximize concurrency when chunking on AutoSharded clients 2020-09-14 03:20:16 -04:00
Rapptz
faf1db1583 Use a lock for the gateway rate limiter.
This will allow for higher concurrency in AutoSharded situations where
I can mostly "fire and forget" the chunk requests.
2020-09-14 03:20:15 -04:00
Rapptz
d6defbc6b2 Heartbeats bypass the rate limits for gateway 2020-09-14 03:20:15 -04:00
Rapptz
5db9a3551f All guilds require chunking if opting into it 2020-09-14 03:20:15 -04:00
Rapptz
e8e4886fd8 Handle user updates within GUILD_MEMBER_UPDATE 2020-09-14 03:20:15 -04:00
Rapptz
95bec0dcee Rewrite chunking to work with intents.
This slows down chunking significantly for bots in a large number of
guilds since it goes down from 75 guilds/request to 1 guild/request.
However the logic was rewritten to fire the chunking request
immediately after receiving the GUILD_CREATE rather than waiting for
all the guilds in the ready stream before doing it.
2020-09-14 03:20:15 -04:00
Rapptz
f46257faa6 Add more close codes that can't be handled for reconnecting. 2020-09-14 03:20:15 -04:00
Rapptz
c0a3aaa98c Change unknown cache log warnings from WARNING -> DEBUG 2020-09-14 03:20:15 -04:00
Rapptz
75c24bde16 Handle gateway rate limits by using a rate limiter.
With the new chunking changes this will become necessary and we don't
want to disconnect from having too many outwards requests.
2020-09-14 03:20:15 -04:00
Rapptz
a9cb851a3c Add support for guild intents 2020-09-14 03:20:15 -04:00
iDutchy
43c4d33a4a avatar urls in stead of assets 2020-09-13 23:57:18 +00:00
iDutchy
4b612aeece Merge pull request #1 from Rapptz/feature/intents
Feature/intents
2020-09-13 05:44:24 +02:00
iDutchy
1791b72f45 Add support for Guild.bots and Guild.humans 2020-09-13 01:17:35 +00:00
Rapptz
77b0ddca7c Raise if member intent is not enabled 2020-09-10 06:46:16 -04:00
Rapptz
61ec62da11 Don't cache members during guild start up if cache is disabled.
This is mainly a half-implemented commit. There are a few more places
where cache consistency is necessary. In the future there will
probably be a member cache policy enum that will be used and cache
consistency will be tackled in part of that larger refactoring.
2020-09-10 05:58:24 -04:00
Rapptz
009a961006 Add Guild.chunk and deprecated Client.request_offline_members 2020-09-10 05:56:48 -04:00
Rapptz
cb211c36bd Fix Client.request_offline_members no longer working 2020-09-10 05:26:35 -04:00
Rapptz
a293d87c77 Add versionadded for intents enum 2020-09-10 05:17:52 -04:00
Rapptz
41fd2740cb Explicitly disable the members presence by default 2020-09-10 05:17:52 -04:00
Rapptz
65f591705d Fix timeouts due to hitting the gateway rate limit 2020-09-10 05:17:52 -04:00
Rapptz
81bfdea9df Maximize the amount of concurrency while chunking.
In order to reduce our amount of backpressure we need to limit the
amount of concurrent chunk requests we can have so the gateway buffer
has some time to breathe.
2020-09-10 05:17:52 -04:00
Rapptz
2129ae29be Check for zombie connections through last received payload
The previous code would check zombie connections depending on whether
HEARTBEAT_ACK was received. Unfortunately when there's exceeding
backpressure the connection can terminate since the HEARTBEAT_ACK is
buffered very far away despite it being there, just not received yet.
2020-09-10 05:17:52 -04:00
Rapptz
82fa967f3c Speed up chunking for guilds with presence intent enabled 2020-09-10 05:17:51 -04:00
Rapptz
fdbe0c4f57 Maximize concurrency when chunking on AutoSharded clients 2020-09-10 05:17:51 -04:00
Rapptz
5837ad0804 Use a lock for the gateway rate limiter.
This will allow for higher concurrency in AutoSharded situations where
I can mostly "fire and forget" the chunk requests.
2020-09-10 05:17:51 -04:00
Rapptz
e6fddbdbe7 Heartbeats bypass the rate limits for gateway 2020-09-10 05:17:51 -04:00
Rapptz
37760e16dd All guilds require chunking if opting into it 2020-09-10 05:17:51 -04:00
Rapptz
fd5faac42b Handle user updates within GUILD_MEMBER_UPDATE 2020-09-10 05:17:50 -04:00
Rapptz
eb641569f7 Rewrite chunking to work with intents.
This slows down chunking significantly for bots in a large number of
guilds since it goes down from 75 guilds/request to 1 guild/request.
However the logic was rewritten to fire the chunking request
immediately after receiving the GUILD_CREATE rather than waiting for
all the guilds in the ready stream before doing it.
2020-09-10 05:17:50 -04:00
Rapptz
51704b10cb Add more close codes that can't be handled for reconnecting. 2020-09-10 05:17:50 -04:00
Rapptz
50a951e3ec Change unknown cache log warnings from WARNING -> DEBUG 2020-09-10 05:17:50 -04:00
Rapptz
63c454eaa0 Handle gateway rate limits by using a rate limiter.
With the new chunking changes this will become necessary and we don't
want to disconnect from having too many outwards requests.
2020-09-10 05:17:50 -04:00
Rapptz
f588834b0c Add support for guild intents 2020-09-10 05:17:50 -04:00
Myst(MysterialPy)
a668623d9f Make admonition-title un-selectable.
All admonition-title's should now be un-selectable.
2020-09-06 00:24:14 -04:00
Nihaal Sangha
0b020fc339 Add sidebar animation when collapsing 2020-09-01 15:24:03 -04:00
Rapptz
0124abb030 Thicken admonition borders a little 2020-08-31 02:35:38 -04:00
Josh
26cce4fb78 [matrix] Hide hamburger menu on pages without sidebar 2020-08-31 02:27:32 -04:00
Josh
3b90e2e74e [matrix] Fix JS errors on search results page 2020-08-30 02:55:53 -04:00
jack1142
512d9aaccb Another take at fixing methods showing up under "Attributes" 2020-08-29 22:22:45 -04:00
Rapptz
39f1f9098e Fix collapsible sidebar not working 2020-08-29 20:39:25 -04:00
Rapptz
994de512cb Use the constructed value in the settings 2020-08-29 20:34:50 -04:00
Muhammad Hamza
597f7e30b8 [matrix] Update model styles 2020-08-29 20:19:49 -04:00
Rapptz
575435b4c9 Fix tooltips in settings and make strings translatable 2020-08-29 20:10:54 -04:00
Rapptz
7d8dae735d Move setting load to DOMContentLoaded 2020-08-29 19:56:28 -04:00
Josh
3ce7ab2fc4 [matrix] Refactor JS & add searchbar to mobile. 2020-08-29 19:17:44 -04:00
Rapptz
2d441cc533 Reduce CSS variable usage 2020-08-29 04:30:35 -04:00
Rapptz
7fec153cd7 Fix versionmodified not being italics 2020-08-29 04:14:08 -04:00
Rapptz
1aa93e70ac Change colour scheme and admonition colours
This should make both themes finally look decent
2020-08-29 04:11:05 -04:00
Josh
42498d26f7 [matrix] Set theme to system preferred by default 2020-08-29 03:57:17 -04:00
Nadir Chowdhury
a9d6d90a8f [matrix] collapsible sidebar headings 2020-08-28 23:13:20 -04:00
jack1142
d9a2c0c65d Fix methods from superclass showing under "Attributes" table 2020-08-22 16:26:50 -04:00
Josh
9cbb801fb0 Fix sidebar jank on desktop. 2020-07-22 23:35:51 -04:00
Josh
41153d6d90 Fix issues with horizontal overflow on mobile 2020-07-08 23:23:52 -04:00
Josh
b2b2d5ac96 Default to sans-serif font 2020-07-08 22:54:23 -04:00
Rapptz
c2a46f3b8b Redesign admonitions to look a little better.
Colours still need to be reworked though.
2020-07-01 03:42:58 -04:00
Josh
a53bf2660b [matrix] Display navbar links on mobile 2020-06-29 19:47:15 -04:00
James
c928fd13f1 Resize favicon to 256x256px for Chrome 2020-06-28 19:15:13 -04:00
Rapptz
597af3a582 Switch icon set over to Material Icons intead of FontAwesome 2020-06-28 18:21:37 -04:00
Rapptz
4ebbeb0f2a Rework attributetable to look prettier 2020-06-28 07:34:04 -04:00
Josh
2a8453828b Fix sidebar scrolling on mobile 2020-06-28 17:42:14 +10:00
Rapptz
7482a5de8d Refactor CSS to use a colour palette and make light theme greyer. 2020-06-28 03:36:59 -04:00
Rapptz
c69f7c7bd8 Make tables scroll if they overflow. 2020-06-27 07:55:47 -04:00
Rapptz
8feb74a018 Revert "Fix table wrapping"
This reverts commit c911cd0dbd.
2020-06-27 07:55:47 -04:00
James
69e2cd0180 Add border radius and padding to inline code 2020-06-27 12:33:51 +01:00
Rapptz
c911cd0dbd Fix table wrapping 2020-06-27 07:08:46 -04:00
Rapptz
f4d53d79df Fix margins in 600px view of settings and label 2020-06-27 06:30:38 -04:00
Josh
f1e9017df1 Fix jank on iPads 2020-06-27 19:42:25 +10:00
Rapptz
0a00aeb335 Show classmethods separately in attribute table 2020-06-27 02:25:50 -04:00
Rapptz
6eba27d98e Alphabetically sort attributetable output 2020-06-27 02:05:23 -04:00
Rapptz
7dd45a422c Show the search bar on mobile 2020-06-27 01:53:41 -04:00
Josh
2ef0695e81 [matrix] General Sidebar cleanup (#5061) 2020-06-27 01:16:37 -04:00
Rapptz
8abd4e1357 Various RTD related fixes. 2020-06-25 03:57:58 -04:00
Josh B
5cb1b109bb Set colours for active sidebar elements 2020-06-08 19:21:44 +10:00
Josh B
3c56240e5f Fix sidebar active link selection 2020-06-01 00:38:37 +10:00
Rapptz
90596485a2 First pass at double header display 2020-05-31 09:12:26 -04:00
Josh B
b78f6a310b Create settings icon for mobile 2020-05-31 00:11:03 +10:00
Rapptz
74bdd8485e Use new HTML5 <section> instead of <div class="section"> 2020-05-30 04:59:31 -04:00
Jens Reidel
f03ecdbc69 [matrix] Search to top, icon
* Search bar to top, magnifying glass

* Remove old file

* Remove empty style directive
2020-05-29 23:42:50 -04:00
Rapptz
d14bf7f412 First pass at centering content for large displays 2020-05-29 09:34:21 -04:00
Rapptz
742b14a705 Add dark theme for codeblocks 2020-05-29 07:23:00 -04:00
Rapptz
71f6b950d1 Actually make overflowing have a scrollbar on mobile 2020-05-29 06:45:44 -04:00
Rapptz
8a94adcbcd Fix codeblock related things with mobile responsiveness. 2020-05-29 06:21:05 -04:00
Josh
a31cf94699 Use default scrollbar for body on webkit browsers 2020-05-29 03:52:56 -04:00
Josh
dc545f570e [matrix] Modal cleanup
* General modal cleanup

* Remove second scrollbar caused by modal
2020-05-29 03:25:13 -04:00
Josh
24c9e7b5fc [matrix] Dark Theme
* Apply width restructions to modals and images

* Dark theme 2.0

* Add webkit scrollbar

* Use Object.keys instead of Object.entries where applicable
2020-05-29 02:57:00 -04:00
Rapptz
38529e6e21 Proper padding for the copy button 2020-05-28 02:21:01 -04:00
Rapptz
439081081c Reverse the related links 2020-05-28 01:33:16 -04:00
Rapptz
aedd40e585 Use html_js_files instead of the old approach to add JS files. 2020-05-28 01:07:17 -04:00
Rapptz
da4e345f3d Cleanup copy button CSS and add a hover-over explanation. 2020-05-28 01:06:06 -04:00
Rapptz
4e9fdc6e4f Rewrite the DOM to use CSS grids
This also rewrites the CSS to use CSS variables. Currently this isn't
done to codeblocks however.
2020-05-27 23:43:58 -04:00
NCPlayz
0a8b87cae7 add copy codeblock button
Apply suggestions from code review

Co-authored-by: Danny <Rapptz@users.noreply.github.com>

Change to icon, change according to slice's review
2020-05-27 23:39:11 -04:00
Josh
38a7cbb6a5 [matrix] Add sans-serif font toggle to settings modal
* Add sans serif font toggle

* remove unnecessary boolean comparison from setFont

Co-authored-by: slice <ryaneft@gmail.com>

* Update checkbox title

Co-authored-by: slice <ryaneft@gmail.com>

* General cleanup of settings system

* Apply overflow hidden to modal

Co-authored-by: slice <ryaneft@gmail.com>
2020-05-27 10:05:40 -04:00
Josh
e6712d76d1 [matrix] Create settings modal
* Create settings modal

* Fix issue with spacing after settings button

* Fix issue with modal background on mobile devices

* Add close button to modal

* Add tooltip to close button

* Support closing modal with escape key

* Add missing semicolon to keydown event listener
2020-05-27 02:22:21 -04:00
Josh
1e471b64e6 [matrix] Refactor docs JS
* Refactor custom.js

* Refactor scorer.js

* tables variable shoudn't be in global scope
2020-05-27 00:56:38 -04:00
Nadir Chowdhury
509cc135d4 Add favicon 2020-05-26 23:18:42 -04:00
Rapptz
04cec0ec10 Use actual viewport tag with initial-scale set to 1 2020-05-26 07:04:58 -04:00
Jens Reidel
f2482d4fb3 Add fixed header links, fix some parts of mobile UI
Dynamic content width equal to old one if on 1080p

Fix mobile view

Disable fixed header on mobile
2020-05-25 22:37:01 -04:00
Rapptz
ccb4e0b6e7 Bump Sphinx to 3.0.3 2020-05-25 22:15:46 -04:00
Riley Shaw
3c558af0cb make documentation sphinx 3.x compatible 2020-05-25 21:39:59 -04:00
Rapptz
2eb9e3bc56 Move table JS outside of scrolling 2020-05-25 12:17:13 -04:00
Rapptz
de9a3b5f60 Bump Sphinx to 2.4.4 2020-05-25 11:55:13 -04:00
Rapptz
771c1c85d8 Add attributetable and add some class-level sections.
The extensions have yet to receive this treatment and CSS needs work,
but for now this is fine.
2020-05-25 11:48:16 -04:00
172 changed files with 49483 additions and 23238 deletions

View File

@@ -1,6 +1,7 @@
name: Bug Report name: Bug Report
description: Report broken or incorrect behaviour description: Report broken or incorrect behaviour
labels: unconfirmed bug labels: unconfirmed bug
issue_body: true
body: body:
- type: markdown - type: markdown
attributes: attributes:
@@ -72,7 +73,3 @@ body:
required: true required: true
- label: I have removed my token from display, if visible. - label: I have removed my token from display, if visible.
required: true required: true
- type: textarea
attributes:
label: Additional Context
description: If there is anything else to say, please do so here.

View File

@@ -1,8 +1,8 @@
blank_issues_enabled: true blank_issues_enabled: false
contact_links: contact_links:
- name: Ask a question - name: Ask a question
about: Ask questions and discuss with other users of the library. about: Ask questions and discuss with other users of the library.
url: https://github.com/iDevision/enhanced-discord.py/issues/new url: https://github.com/Rapptz/discord.py/discussions
- name: Discord Server - name: Discord Server
about: Use our official Discord server to ask for help and questions as well. about: Use our official Discord server to ask for help and questions as well.
url: https://discord.gg/TvqYBrGXEm url: https://discord.gg/r3sSKJJ

View File

@@ -1,6 +1,7 @@
name: Feature Request name: Feature Request
description: Suggest a feature for this library description: Suggest a feature for this library
labels: feature request labels: feature request
issue_body: true
body: body:
- type: input - type: input
attributes: attributes:
@@ -43,7 +44,3 @@ body:
What is the current solution to the problem, if any? What is the current solution to the problem, if any?
validations: validations:
required: false required: false
- type: textarea
attributes:
label: Additional Context
description: If there is anything else to say, please do so here.

2
.gitignore vendored
View File

@@ -14,3 +14,5 @@ docs/crowdin.py
*.jpg *.jpg
*.flac *.flac
*.mo *.mo
dist
build

View File

@@ -2,4 +2,3 @@ include README.rst
include LICENSE include LICENSE
include requirements.txt include requirements.txt
include discord/bin/*.dll include discord/bin/*.dll
include discord/py.typed

114
README.ja.rst Normal file
View File

@@ -0,0 +1,114 @@
discord.py
==========
.. image:: https://discord.com/api/guilds/336642139381301249/embed.png
:target: https://discord.gg/nXzj3dg
:alt: Discordサーバーの招待
.. image:: https://img.shields.io/pypi/v/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPIのバージョン情報
.. image:: https://img.shields.io/pypi/pyversions/discord.py.svg
:target: https://pypi.python.org/pypi/discord.py
:alt: PyPIのサポートしているPythonのバージョン
discord.py は機能豊富かつモダンで使いやすい、非同期処理にも対応したDiscord用のAPIラッパーです。
主な特徴
-------------
- ``async````await`` を使ったモダンなPythonらしいAPI。
- 適切なレート制限処理
- Discord APIによってサポートされているものを100カバー。
- メモリと速度の両方を最適化。
インストール
-------------
**Python 3.5.3 以降のバージョンが必須です**
完全な音声サポートなしでライブラリをインストールする場合は次のコマンドを実行してください:
.. code:: sh
# Linux/OS X
python3 -m pip install -U discord.py
# Windows
py -3 -m pip install -U discord.py
音声サポートが必要なら、次のコマンドを実行しましょう:
.. code:: sh
# Linux/OS X
python3 -m pip install -U discord.py[voice]
# Windows
py -3 -m pip install -U discord.py[voice]
開発版をインストールしたいのならば、次の手順に従ってください:
.. code:: sh
$ git clone https://github.com/Rapptz/discord.py
$ cd discord.py
$ python3 -m pip install -U .[voice]
オプションパッケージ
~~~~~~~~~~~~~~~~~~~~~~
* PyNaCl (音声サポート用)
Linuxで音声サポートを導入するには、前述のコマンドを実行する前にお気に入りのパッケージマネージャー(例えば ``apt````dnf`` など)を使って以下のパッケージをインストールする必要があります:
* libffi-dev (システムによっては ``libffi-devel``)
* python-dev (例えばPython 3.6用の ``python3.6-dev``)
簡単な例
--------------
.. code:: py
import discord
class MyClient(discord.Client):
async def on_ready(self):
print('Logged on as', self.user)
async def on_message(self, message):
# don't respond to ourselves
if message.author == self.user:
return
if message.content == 'ping':
await message.channel.send('pong')
client = MyClient()
client.run('token')
Botの例
~~~~~~~~~~~~~
.. code:: py
import discord
from discord.ext import commands
bot = commands.Bot(command_prefix='>')
@bot.command()
async def ping(ctx):
await ctx.send('pong')
bot.run('token')
examplesディレクトリに更に多くのサンプルがあります。
リンク
------
- `ドキュメント <https://discordpy.readthedocs.io/ja/latest/index.html>`_
- `公式Discordサーバー <https://discord.gg/nXzj3dg>`_
- `Discord API <https://discord.gg/discord-api>`_

View File

@@ -1,47 +1,32 @@
discord.py Enhanced-dpy (custom discord.py)
========== =================================
.. image:: https://discord.com/api/guilds/514232441498763279/embed.png .. image:: https://img.shields.io/pypi/pyversions/discord.py.svg
:target: https://discord.gg/PYAfZzpsjG :target: https://pypi.python.org/pypi/discord.py
:alt: Discord server invite
.. image:: https://img.shields.io/pypi/v/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy
:alt: PyPI version info
.. image:: https://img.shields.io/pypi/pyversions/enhanced-dpy.svg
:target: https://pypi.python.org/pypi/enhanced-dpy
:alt: PyPI supported Python versions :alt: PyPI supported Python versions
A modern, maintained, easy to use, feature-rich, and async ready API wrapper for Discord written in Python. A modern, easy to use, feature-rich, and async ready API wrapper for Discord written in Python.
Credits to the `original lib by Rapptz <https://github.com/Rapptz/discord.py>`_
The Future of enhanced-discord.py **WARNING: This is not the official discord.py library! As of 8/27/2021 Danny (Rapptz) has stopped development due to breaking changes. You are still able to read the official library at https://github.com/Rapptz/discord.py!**
--------------------------
Enhanced discord.py is a fork of Rapptz's discord.py, that went unmaintained (`gist <https://gist.github.com/Rapptz/4a2f62751b9600a31a0d3c78100287f1>`_) Custom Features
----------------
It is currently maintained by (in alphabetical order) **Moved to:** `Custom Features <https://enhanced-dpy.readthedocs.io/en/latest/custom_features.html>`_
- Chillymosh#8175
- Daggy#9889
- dank Had0cK#6081
- Dutchy#6127
- Eyesofcreeper#0001
- Gnome!#6669
- IAmTomahawkx#1000
- Jadon#2494
An overview of added features is available on the `custom features page <https://enhanced-dpy.readthedocs.io/en/latest/index.html#custom-features>`_.
Key Features Key Features
------------- -------------
- Modern Pythonic API using ``async`` and ``await``. - Modern Pythonic API using ``async`` and ``await``.
- Proper rate limit handling. - Proper rate limit handling.
- 100% coverage of the supported Discord API.
- Optimised in both speed and memory. - Optimised in both speed and memory.
Installing Installing
---------- ----------
**Python 3.8 or higher is required** **Python 3.5.3 or higher is required**
To install the library without full voice support, you can just run the following command: To install the library without full voice support, you can just run the following command:
@@ -53,20 +38,19 @@ To install the library without full voice support, you can just run the followin
# Windows # Windows
py -3 -m pip install -U enhanced-dpy py -3 -m pip install -U enhanced-dpy
To install the development version, do the following: To install the development version, do the following:
.. code:: sh .. code:: sh
$ git clone https://github.com/iDevision/enhanced-discord.py $ git clone https://github.com/iDevision/enhanced-discord.py
$ cd discord.py $ cd enhanced-discord.py
$ python3 -m pip install -U .[voice] $ python3 -m pip install -U .[voice]
Optional Packages Optional Packages
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
* `PyNaCl <https://pypi.org/project/PyNaCl/>`__ (for voice support) * PyNaCl (for voice support)
Please note that on Linux installing voice you must install the following packages via your favourite package manager (e.g. ``apt``, ``dnf``, etc) before running the above commands: Please note that on Linux installing voice you must install the following packages via your favourite package manager (e.g. ``apt``, ``dnf``, etc) before running the above commands:
@@ -117,5 +101,5 @@ Links
------ ------
- `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_ - `Documentation <https://enhanced-dpy.readthedocs.io/en/latest/index.html>`_
- `Official Discord Server <https://discord.gg/PYAfZzpsjG>`_ - `Official Discord Server <https://discord.gg/wZSH7pz>`_
- `Discord API <https://discord.gg/discord-api>`_ - `Discord API <https://discord.gg/discord-api>`_

View File

@@ -13,12 +13,12 @@ __title__ = 'discord'
__author__ = 'Rapptz' __author__ = 'Rapptz'
__license__ = 'MIT' __license__ = 'MIT'
__copyright__ = 'Copyright 2015-present Rapptz' __copyright__ = 'Copyright 2015-present Rapptz'
__version__ = '2.0.0a' __version__ = '1.7.3.7.post1'
__path__ = __import__('pkgutil').extend_path(__path__, __name__) __path__ = __import__('pkgutil').extend_path(__path__, __name__)
from collections import namedtuple
import logging import logging
from typing import NamedTuple, Literal
from .client import * from .client import *
from .appinfo import * from .appinfo import *
@@ -43,7 +43,7 @@ from .template import *
from .widget import * from .widget import *
from .object import * from .object import *
from .reaction import * from .reaction import *
from . import utils, opus, abc, ui from . import utils, opus, abc
from .enums import * from .enums import *
from .embeds import * from .embeds import *
from .mentions import * from .mentions import *
@@ -55,20 +55,10 @@ from .audit_logs import *
from .raw_models import * from .raw_models import *
from .team import * from .team import *
from .sticker import * from .sticker import *
from .stage_instance import *
from .interactions import * from .interactions import *
from .components import *
from .threads import *
VersionInfo = namedtuple('VersionInfo', 'major minor micro enhanced releaselevel serial')
class VersionInfo(NamedTuple): version_info = VersionInfo(major=1, minor=7, micro=3, enhanced=7, releaselevel='final', serial=0)
major: int
minor: int
micro: int
releaselevel: Literal["alpha", "beta", "candidate", "final"]
serial: int
version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0)
logging.getLogger(__name__).addHandler(logging.NullHandler()) logging.getLogger(__name__).addHandler(logging.NullHandler())

View File

@@ -51,7 +51,7 @@ def core(parser, args):
if args.version: if args.version:
show_version() show_version()
_bot_template = """#!/usr/bin/env python3 bot_template = """#!/usr/bin/env python3
from discord.ext import commands from discord.ext import commands
import discord import discord
@@ -64,10 +64,10 @@ class Bot(commands.{base}):
try: try:
self.load_extension(cog) self.load_extension(cog)
except Exception as exc: except Exception as exc:
print(f'Could not load extension {{cog}} due to {{exc.__class__.__name__}}: {{exc}}') print('Could not load extension {{0}} due to {{1.__class__.__name__}}: {{1}}'.format(cog, exc))
async def on_ready(self): async def on_ready(self):
print(f'Logged on as {{self.user}} (ID: {{self.user.id}})') print('Logged on as {{0}} (ID: {{0.id}})'.format(self.user))
bot = Bot() bot = Bot()
@@ -77,7 +77,7 @@ bot = Bot()
bot.run(config.token) bot.run(config.token)
""" """
_gitignore_template = """# Byte-compiled / optimized / DLL files gitignore_template = """# Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
@@ -107,7 +107,7 @@ var/
config.py config.py
""" """
_cog_template = '''from discord.ext import commands cog_template = '''from discord.ext import commands
import discord import discord
class {name}(commands.Cog{attrs}): class {name}(commands.Cog{attrs}):
@@ -120,7 +120,7 @@ def setup(bot):
bot.add_cog({name}(bot)) bot.add_cog({name}(bot))
''' '''
_cog_extras = ''' cog_extras = '''
def cog_unload(self): def cog_unload(self):
# clean up logic goes here # clean up logic goes here
pass pass
@@ -170,7 +170,7 @@ _base_table = {
# NUL (0) and 1-31 are disallowed # NUL (0) and 1-31 are disallowed
_base_table.update((chr(i), None) for i in range(32)) _base_table.update((chr(i), None) for i in range(32))
_translation_table = str.maketrans(_base_table) translation_table = str.maketrans(_base_table)
def to_path(parser, name, *, replace_spaces=False): def to_path(parser, name, *, replace_spaces=False):
if isinstance(name, Path): if isinstance(name, Path):
@@ -182,7 +182,7 @@ def to_path(parser, name, *, replace_spaces=False):
if len(name) <= 4 and name.upper() in forbidden: if len(name) <= 4 and name.upper() in forbidden:
parser.error('invalid directory name given, use a different one') parser.error('invalid directory name given, use a different one')
name = name.translate(_translation_table) name = name.translate(translation_table)
if replace_spaces: if replace_spaces:
name = name.replace(' ', '-') name = name.replace(' ', '-')
return Path(name) return Path(name)
@@ -215,14 +215,14 @@ def newbot(parser, args):
try: try:
with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp: with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp:
base = 'Bot' if not args.sharded else 'AutoShardedBot' base = 'Bot' if not args.sharded else 'AutoShardedBot'
fp.write(_bot_template.format(base=base, prefix=args.prefix)) fp.write(bot_template.format(base=base, prefix=args.prefix))
except OSError as exc: except OSError as exc:
parser.error(f'could not create bot file ({exc})') parser.error(f'could not create bot file ({exc})')
if not args.no_git: if not args.no_git:
try: try:
with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp: with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp:
fp.write(_gitignore_template) fp.write(gitignore_template)
except OSError as exc: except OSError as exc:
print(f'warning: could not create .gitignore file ({exc})') print(f'warning: could not create .gitignore file ({exc})')
@@ -240,7 +240,7 @@ def newcog(parser, args):
try: try:
with open(str(directory), 'w', encoding='utf-8') as fp: with open(str(directory), 'w', encoding='utf-8') as fp:
attrs = '' attrs = ''
extra = _cog_extras if args.full else '' extra = cog_extras if args.full else ''
if args.class_name: if args.class_name:
name = args.class_name name = args.class_name
else: else:
@@ -255,7 +255,7 @@ def newcog(parser, args):
attrs += f', name="{args.display_name}"' attrs += f', name="{args.display_name}"'
if args.hide_commands: if args.hide_commands:
attrs += ', command_attrs=dict(hidden=True)' attrs += ', command_attrs=dict(hidden=True)'
fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) fp.write(cog_template.format(name=name, extra=extra, attrs=attrs))
except OSError as exc: except OSError as exc:
parser.error(f'could not create cog file ({exc})') parser.error(f'could not create cog file ({exc})')
else: else:

File diff suppressed because it is too large Load Diff

View File

@@ -22,10 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, overload
from .asset import Asset from .asset import Asset
from .enums import ActivityType, try_enum from .enums import ActivityType, try_enum
@@ -74,9 +71,6 @@ type: int
sync_id: str sync_id: str
session_id: str session_id: str
flags: int flags: int
buttons: list[dict]
label: str (max: 32)
url: str (max: 512)
There are also activity flags which are mostly uninteresting for the library atm. There are also activity flags which are mostly uninteresting for the library atm.
@@ -90,16 +84,6 @@ t.ActivityFlags = {
} }
""" """
if TYPE_CHECKING:
from .types.activity import (
Activity as ActivityPayload,
ActivityTimestamps,
ActivityParty,
ActivityAssets,
ActivityButton,
)
class BaseActivity: class BaseActivity:
"""The base activity that all user-settable activities inherit from. """The base activity that all user-settable activities inherit from.
A user-settable activity is one that can be used in :meth:`Client.change_presence`. A user-settable activity is one that can be used in :meth:`Client.change_presence`.
@@ -118,24 +102,19 @@ class BaseActivity:
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('_created_at',) __slots__ = ('_created_at',)
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._created_at: Optional[float] = kwargs.pop('created_at', None) self._created_at = kwargs.pop('created_at', None)
@property @property
def created_at(self) -> Optional[datetime.datetime]: def created_at(self):
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC. """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC.
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self._created_at is not None: if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._created_at / 1000)
def to_dict(self) -> ActivityPayload:
raise NotImplementedError
class Activity(BaseActivity): class Activity(BaseActivity):
"""Represents an activity in Discord. """Represents an activity in Discord.
@@ -151,17 +130,17 @@ class Activity(BaseActivity):
Attributes Attributes
------------ ------------
application_id: Optional[:class:`int`] application_id: :class:`int`
The application ID of the game. The application ID of the game.
name: Optional[:class:`str`] name: :class:`str`
The name of the activity. The name of the activity.
url: Optional[:class:`str`] url: :class:`str`
A stream URL that the activity could be doing. A stream URL that the activity could be doing.
type: :class:`ActivityType` type: :class:`ActivityType`
The type of activity currently being done. The type of activity currently being done.
state: Optional[:class:`str`] state: :class:`str`
The user's current state. For example, "In Game". The user's current state. For example, "In Game".
details: Optional[:class:`str`] details: :class:`str`
The detail of the user's current activity. The detail of the user's current activity.
timestamps: :class:`dict` timestamps: :class:`dict`
A dictionary of timestamps. It contains the following optional keys: A dictionary of timestamps. It contains the following optional keys:
@@ -185,61 +164,38 @@ class Activity(BaseActivity):
- ``id``: A string representing the party ID. - ``id``: A string representing the party ID.
- ``size``: A list of up to two integer elements denoting (current_size, maximum_size). - ``size``: A list of up to two integer elements denoting (current_size, maximum_size).
buttons: List[:class:`dict`]
An list of dictionaries representing custom buttons shown in a rich presence.
Each dictionary contains the following keys:
- ``label``: A string representing the text shown on the button.
- ``url``: A string representing the URL opened upon clicking the button.
.. versionadded:: 2.0
emoji: Optional[:class:`PartialEmoji`] emoji: Optional[:class:`PartialEmoji`]
The emoji that belongs to this activity. The emoji that belongs to this activity.
""" """
__slots__ = ( __slots__ = ('state', 'details', '_created_at', 'timestamps', 'assets', 'party',
'state', 'flags', 'sync_id', 'session_id', 'type', 'name', 'url',
'details', 'application_id', 'emoji')
'_created_at',
'timestamps',
'assets',
'party',
'flags',
'sync_id',
'session_id',
'type',
'name',
'url',
'application_id',
'emoji',
'buttons',
)
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.state: Optional[str] = kwargs.pop('state', None) self.state = kwargs.pop('state', None)
self.details: Optional[str] = kwargs.pop('details', None) self.details = kwargs.pop('details', None)
self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {}) self.timestamps = kwargs.pop('timestamps', {})
self.assets: ActivityAssets = kwargs.pop('assets', {}) self.assets = kwargs.pop('assets', {})
self.party: ActivityParty = kwargs.pop('party', {}) self.party = kwargs.pop('party', {})
self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id') self.application_id = _get_as_snowflake(kwargs, 'application_id')
self.name: Optional[str] = kwargs.pop('name', None) self.name = kwargs.pop('name', None)
self.url: Optional[str] = kwargs.pop('url', None) self.url = kwargs.pop('url', None)
self.flags: int = kwargs.pop('flags', 0) self.flags = kwargs.pop('flags', 0)
self.sync_id: Optional[str] = kwargs.pop('sync_id', None) self.sync_id = kwargs.pop('sync_id', None)
self.session_id: Optional[str] = kwargs.pop('session_id', None) self.session_id = kwargs.pop('session_id', None)
self.buttons: List[ActivityButton] = kwargs.pop('buttons', [])
activity_type = kwargs.pop('type', -1) activity_type = kwargs.pop('type', -1)
self.type: ActivityType = ( self.type = activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type)
)
emoji = kwargs.pop('emoji', None) emoji = kwargs.pop('emoji', None)
self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None if emoji is not None:
self.emoji = PartialEmoji.from_dict(emoji)
else:
self.emoji = None
def __repr__(self) -> str: def __repr__(self):
attrs = ( attrs = (
('type', self.type), ('type', self.type),
('name', self.name), ('name', self.name),
@@ -252,8 +208,8 @@ class Activity(BaseActivity):
inner = ' '.join('%s=%r' % t for t in attrs) inner = ' '.join('%s=%r' % t for t in attrs)
return f'<Activity {inner}>' return f'<Activity {inner}>'
def to_dict(self) -> Dict[str, Any]: def to_dict(self):
ret: Dict[str, Any] = {} ret = {}
for attr in self.__slots__: for attr in self.__slots__:
value = getattr(self, attr, None) value = getattr(self, attr, None)
if value is None: if value is None:
@@ -269,27 +225,27 @@ class Activity(BaseActivity):
return ret return ret
@property @property
def start(self) -> Optional[datetime.datetime]: def start(self):
"""Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable."""
try: try:
timestamp = self.timestamps['start'] / 1000 timestamp = self.timestamps['start'] / 1000
except KeyError: except KeyError:
return None return None
else: else:
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
@property @property
def end(self) -> Optional[datetime.datetime]: def end(self):
"""Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable."""
try: try:
timestamp = self.timestamps['end'] / 1000 timestamp = self.timestamps['end'] / 1000
except KeyError: except KeyError:
return None return None
else: else:
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
@property @property
def large_image_url(self) -> Optional[str]: def large_image_url(self):
"""Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable.""" """Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable."""
if self.application_id is None: if self.application_id is None:
return None return None
@@ -302,7 +258,7 @@ class Activity(BaseActivity):
return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png' return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png'
@property @property
def small_image_url(self) -> Optional[str]: def small_image_url(self):
"""Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable.""" """Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable."""
if self.application_id is None: if self.application_id is None:
return None return None
@@ -313,14 +269,13 @@ class Activity(BaseActivity):
return None return None
else: else:
return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png' return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png'
@property @property
def large_image_text(self) -> Optional[str]: def large_image_text(self):
"""Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable."""
return self.assets.get('large_text', None) return self.assets.get('large_text', None)
@property @property
def small_image_text(self) -> Optional[str]: def small_image_text(self):
"""Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable."""
return self.assets.get('small_text', None) return self.assets.get('small_text', None)
@@ -361,12 +316,12 @@ class Game(BaseActivity):
__slots__ = ('name', '_end', '_start') __slots__ = ('name', '_end', '_start')
def __init__(self, name: str, **extra): def __init__(self, name, **extra):
super().__init__(**extra) super().__init__(**extra)
self.name: str = name self.name = name
try: try:
timestamps: ActivityTimestamps = extra['timestamps'] timestamps = extra['timestamps']
except KeyError: except KeyError:
self._start = 0 self._start = 0
self._end = 0 self._end = 0
@@ -375,7 +330,7 @@ class Game(BaseActivity):
self._end = timestamps.get('end', 0) self._end = timestamps.get('end', 0)
@property @property
def type(self) -> ActivityType: def type(self):
""":class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.playing`. It always returns :attr:`ActivityType.playing`.
@@ -383,51 +338,48 @@ class Game(BaseActivity):
return ActivityType.playing return ActivityType.playing
@property @property
def start(self) -> Optional[datetime.datetime]: def start(self):
"""Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable."""
if self._start: if self._start:
return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._start / 1000).replace(tzinfo=datetime.timezone.utc)
return None return None
@property @property
def end(self) -> Optional[datetime.datetime]: def end(self):
"""Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable."""
if self._end: if self._end:
return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._end / 1000).replace(tzinfo=datetime.timezone.utc)
return None return None
def __str__(self) -> str: def __str__(self):
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self):
return f'<Game name={self.name!r}>' return f'<Game name={self.name!r}>'
def to_dict(self) -> Dict[str, Any]: def to_dict(self):
timestamps: Dict[str, Any] = {} timestamps = {}
if self._start: if self._start:
timestamps['start'] = self._start timestamps['start'] = self._start
if self._end: if self._end:
timestamps['end'] = self._end timestamps['end'] = self._end
# fmt: off
return { return {
'type': ActivityType.playing.value, 'type': ActivityType.playing.value,
'name': str(self.name), 'name': str(self.name),
'timestamps': timestamps 'timestamps': timestamps
} }
# fmt: on
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, Game) and other.name == self.name return isinstance(other, Game) and other.name == self.name
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash(self.name) return hash(self.name)
class Streaming(BaseActivity): class Streaming(BaseActivity):
"""A slimmed down version of :class:`Activity` that represents a Discord streaming status. """A slimmed down version of :class:`Activity` that represents a Discord streaming status.
@@ -453,7 +405,7 @@ class Streaming(BaseActivity):
Attributes Attributes
----------- -----------
platform: Optional[:class:`str`] platform: :class:`str`
Where the user is streaming from (ie. YouTube, Twitch). Where the user is streaming from (ie. YouTube, Twitch).
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -475,27 +427,27 @@ class Streaming(BaseActivity):
__slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets')
def __init__(self, *, name: Optional[str], url: str, **extra: Any): def __init__(self, *, name, url, **extra):
super().__init__(**extra) super().__init__(**extra)
self.platform: Optional[str] = name self.platform = name
self.name: Optional[str] = extra.pop('details', name) self.name = extra.pop('details', name)
self.game: Optional[str] = extra.pop('state', None) self.game = extra.pop('state', None)
self.url: str = url self.url = url
self.details: Optional[str] = extra.pop('details', self.name) # compatibility self.details = extra.pop('details', self.name) # compatibility
self.assets: ActivityAssets = extra.pop('assets', {}) self.assets = extra.pop('assets', {})
@property @property
def type(self) -> ActivityType: def type(self):
""":class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`. """:class:`ActivityType`: Returns the game's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.streaming`. It always returns :attr:`ActivityType.streaming`.
""" """
return ActivityType.streaming return ActivityType.streaming
def __str__(self) -> str: def __str__(self):
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self):
return f'<Streaming name={self.name!r}>' return f'<Streaming name={self.name!r}>'
@property @property
@@ -513,29 +465,26 @@ class Streaming(BaseActivity):
else: else:
return name[7:] if name[:7] == 'twitch:' else None return name[7:] if name[:7] == 'twitch:' else None
def to_dict(self) -> Dict[str, Any]: def to_dict(self):
# fmt: off ret = {
ret: Dict[str, Any] = {
'type': ActivityType.streaming.value, 'type': ActivityType.streaming.value,
'name': str(self.name), 'name': str(self.name),
'url': str(self.url), 'url': str(self.url),
'assets': self.assets 'assets': self.assets
} }
# fmt: on
if self.details: if self.details:
ret['details'] = self.details ret['details'] = self.details
return ret return ret
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, Streaming) and other.name == self.name and other.url == self.url return isinstance(other, Streaming) and other.name == self.name and other.url == self.url
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash(self.name) return hash(self.name)
class Spotify: class Spotify:
"""Represents a Spotify listening activity from Discord. This is a special case of """Represents a Spotify listening activity from Discord. This is a special case of
:class:`Activity` that makes it easier to work with the Spotify integration. :class:`Activity` that makes it easier to work with the Spotify integration.
@@ -559,20 +508,21 @@ class Spotify:
Returns the string 'Spotify'. Returns the string 'Spotify'.
""" """
__slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id',
'_created_at')
def __init__(self, **data): def __init__(self, **data):
self._state: str = data.pop('state', '') self._state = data.pop('state', None)
self._details: str = data.pop('details', '') self._details = data.pop('details', None)
self._timestamps: Dict[str, int] = data.pop('timestamps', {}) self._timestamps = data.pop('timestamps', {})
self._assets: ActivityAssets = data.pop('assets', {}) self._assets = data.pop('assets', {})
self._party: ActivityParty = data.pop('party', {}) self._party = data.pop('party', {})
self._sync_id: str = data.pop('sync_id') self._sync_id = data.pop('sync_id')
self._session_id: str = data.pop('session_id') self._session_id = data.pop('session_id')
self._created_at: Optional[float] = data.pop('created_at', None) self._created_at = data.pop('created_at', None)
@property @property
def type(self) -> ActivityType: def type(self):
""":class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.listening`. It always returns :attr:`ActivityType.listening`.
@@ -580,31 +530,31 @@ class Spotify:
return ActivityType.listening return ActivityType.listening
@property @property
def created_at(self) -> Optional[datetime.datetime]: def created_at(self):
"""Optional[:class:`datetime.datetime`]: When the user started listening in UTC. """Optional[:class:`datetime.datetime`]: When the user started listening in UTC.
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self._created_at is not None: if self._created_at is not None:
return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._created_at / 1000)
@property @property
def colour(self) -> Colour: def colour(self):
""":class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`.
There is an alias for this named :attr:`color`""" There is an alias for this named :attr:`color`"""
return Colour(0x1DB954) return Colour(0x1db954)
@property @property
def color(self) -> Colour: def color(self):
""":class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`. """:class:`Colour`: Returns the Spotify integration colour, as a :class:`Colour`.
There is an alias for this named :attr:`colour`""" There is an alias for this named :attr:`colour`"""
return self.colour return self.colour
def to_dict(self) -> Dict[str, Any]: def to_dict(self):
return { return {
'flags': 48, # SYNC | PLAY 'flags': 48, # SYNC | PLAY
'name': 'Spotify', 'name': 'Spotify',
'assets': self._assets, 'assets': self._assets,
'party': self._party, 'party': self._party,
@@ -612,46 +562,42 @@ class Spotify:
'session_id': self._session_id, 'session_id': self._session_id,
'timestamps': self._timestamps, 'timestamps': self._timestamps,
'details': self._details, 'details': self._details,
'state': self._state, 'state': self._state
} }
@property @property
def name(self) -> str: def name(self):
""":class:`str`: The activity's name. This will always return "Spotify".""" """:class:`str`: The activity's name. This will always return "Spotify"."""
return 'Spotify' return 'Spotify'
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return ( return (isinstance(other, Spotify) and other._session_id == self._session_id
isinstance(other, Spotify) and other._sync_id == self._sync_id and other.start == self.start)
and other._session_id == self._session_id
and other._sync_id == self._sync_id
and other.start == self.start
)
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash(self._session_id) return hash(self._session_id)
def __str__(self) -> str: def __str__(self):
return 'Spotify' return 'Spotify'
def __repr__(self) -> str: def __repr__(self):
return f'<Spotify title={self.title!r} artist={self.artist!r} track_id={self.track_id!r}>' return '<Spotify title={0.title!r} artist={0.artist!r} track_id={0.track_id!r}>'.format(self)
@property @property
def title(self) -> str: def title(self):
""":class:`str`: The title of the song being played.""" """:class:`str`: The title of the song being played."""
return self._details return self._details
@property @property
def artists(self) -> List[str]: def artists(self):
"""List[:class:`str`]: The artists of the song being played.""" """List[:class:`str`]: The artists of the song being played."""
return self._state.split('; ') return self._state.split('; ')
@property @property
def artist(self) -> str: def artist(self):
""":class:`str`: The artist of the song being played. """:class:`str`: The artist of the song being played.
This does not attempt to split the artist information into This does not attempt to split the artist information into
@@ -660,12 +606,12 @@ class Spotify:
return self._state return self._state
@property @property
def album(self) -> str: def album(self):
""":class:`str`: The album that the song being played belongs to.""" """:class:`str`: The album that the song being played belongs to."""
return self._assets.get('large_text', '') return self._assets.get('large_text', '')
@property @property
def album_cover_url(self) -> str: def album_cover_url(self):
""":class:`str`: The album cover image URL from Spotify's CDN.""" """:class:`str`: The album cover image URL from Spotify's CDN."""
large_image = self._assets.get('large_image', '') large_image = self._assets.get('large_image', '')
if large_image[:8] != 'spotify:': if large_image[:8] != 'spotify:':
@@ -674,39 +620,30 @@ class Spotify:
return 'https://i.scdn.co/image/' + album_image_id return 'https://i.scdn.co/image/' + album_image_id
@property @property
def track_id(self) -> str: def track_id(self):
""":class:`str`: The track ID used by Spotify to identify this song.""" """:class:`str`: The track ID used by Spotify to identify this song."""
return self._sync_id return self._sync_id
@property @property
def track_url(self) -> str: def start(self):
""":class:`str`: The track URL to listen on Spotify.
.. versionadded:: 2.0
"""
return f'https://open.spotify.com/track/{self.track_id}'
@property
def start(self) -> datetime.datetime:
""":class:`datetime.datetime`: When the user started playing this song in UTC.""" """:class:`datetime.datetime`: When the user started playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._timestamps['start'] / 1000)
@property @property
def end(self) -> datetime.datetime: def end(self):
""":class:`datetime.datetime`: When the user will stop playing this song in UTC.""" """:class:`datetime.datetime`: When the user will stop playing this song in UTC."""
return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(self._timestamps['end'] / 1000)
@property @property
def duration(self) -> datetime.timedelta: def duration(self):
""":class:`datetime.timedelta`: The duration of the song being played.""" """:class:`datetime.timedelta`: The duration of the song being played."""
return self.end - self.start return self.end - self.start
@property @property
def party_id(self) -> str: def party_id(self):
""":class:`str`: The party ID of the listening party.""" """:class:`str`: The party ID of the listening party."""
return self._party.get('id', '') return self._party.get('id', '')
class CustomActivity(BaseActivity): class CustomActivity(BaseActivity):
"""Represents a Custom activity from Discord. """Represents a Custom activity from Discord.
@@ -740,14 +677,13 @@ class CustomActivity(BaseActivity):
__slots__ = ('name', 'emoji', 'state') __slots__ = ('name', 'emoji', 'state')
def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): def __init__(self, name, *, emoji=None, **extra):
super().__init__(**extra) super().__init__(**extra)
self.name: Optional[str] = name self.name = name
self.state: Optional[str] = extra.pop('state', None) self.state = extra.pop('state', None)
if self.name == 'Custom Status': if self.name == 'Custom Status':
self.name = self.state self.name = self.state
self.emoji: Optional[PartialEmoji]
if emoji is None: if emoji is None:
self.emoji = emoji self.emoji = emoji
elif isinstance(emoji, dict): elif isinstance(emoji, dict):
@@ -760,14 +696,14 @@ class CustomActivity(BaseActivity):
raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.') raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.')
@property @property
def type(self) -> ActivityType: def type(self):
""":class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`. """:class:`ActivityType`: Returns the activity's type. This is for compatibility with :class:`Activity`.
It always returns :attr:`ActivityType.custom`. It always returns :attr:`ActivityType.custom`.
""" """
return ActivityType.custom return ActivityType.custom
def to_dict(self) -> Dict[str, Any]: def to_dict(self):
if self.name == self.state: if self.name == self.state:
o = { o = {
'type': ActivityType.custom.value, 'type': ActivityType.custom.value,
@@ -784,16 +720,16 @@ class CustomActivity(BaseActivity):
o['emoji'] = self.emoji.to_dict() o['emoji'] = self.emoji.to_dict()
return o return o
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji return (isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji)
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash((self.name, str(self.emoji))) return hash((self.name, str(self.emoji)))
def __str__(self) -> str: def __str__(self):
if self.emoji: if self.emoji:
if self.name: if self.name:
return f'{self.emoji} {self.name}' return f'{self.emoji} {self.name}'
@@ -801,21 +737,11 @@ class CustomActivity(BaseActivity):
else: else:
return str(self.name) return str(self.name)
def __repr__(self) -> str: def __repr__(self):
return f'<CustomActivity name={self.name!r} emoji={self.emoji!r}>' return '<CustomActivity name={0.name!r} emoji={0.emoji!r}>'.format(self)
ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] def create_activity(data):
@overload
def create_activity(data: ActivityPayload) -> ActivityTypes:
...
@overload
def create_activity(data: None) -> None:
...
def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
if not data: if not data:
return None return None
@@ -830,12 +756,10 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]:
except KeyError: except KeyError:
return Activity(**data) return Activity(**data)
else: else:
# we removed the name key from data already return CustomActivity(name=name, **data)
return CustomActivity(name=name, **data) # type: ignore
elif game_type is ActivityType.streaming: elif game_type is ActivityType.streaming:
if 'url' in data: if 'url' in data:
# the url won't be None here return Streaming(**data)
return Streaming(**data) # type: ignore
return Activity(**data) return Activity(**data)
elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data:
return Spotify(**data) return Spotify(**data)

View File

@@ -22,29 +22,15 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import List, TYPE_CHECKING, Optional
from . import utils from . import utils
from .user import User
from .asset import Asset from .asset import Asset
from .team import Team
if TYPE_CHECKING:
from .guild import Guild
from .types.appinfo import (
AppInfo as AppInfoPayload,
PartialAppInfo as PartialAppInfoPayload,
Team as TeamPayload,
)
from .user import User
from .state import ConnectionState
__all__ = ( __all__ = (
'AppInfo', 'AppInfo',
'PartialAppInfo',
) )
class AppInfo: class AppInfo:
"""Represents the application info for the bot provided by Discord. """Represents the application info for the bot provided by Discord.
@@ -62,7 +48,9 @@ class AppInfo:
.. versionadded:: 1.3 .. versionadded:: 1.3
description: :class:`str` icon: Optional[:class:`str`]
The icon hash, if it exists.
description: Optional[:class:`str`]
The application description. The application description.
bot_public: :class:`bool` bot_public: :class:`bool`
Whether the bot can be invited by anyone or if it is locked Whether the bot can be invited by anyone or if it is locked
@@ -103,145 +91,128 @@ class AppInfo:
.. versionadded:: 1.3 .. versionadded:: 1.3
terms_of_service_url: Optional[:class:`str`] cover_image: Optional[:class:`str`]
The application's terms of service URL, if set. If this application is a game sold on Discord,
this field will be the hash of the image on store embeds
.. versionadded:: 2.0 .. versionadded:: 1.3
privacy_policy_url: Optional[:class:`str`]
The application's privacy policy URL, if set.
.. versionadded:: 2.0
""" """
__slots__ = ('_state', 'description', 'id', 'name', 'rpc_origins',
'bot_public', 'bot_require_code_grant', 'owner', 'icon',
'summary', 'verify_key', 'team', 'guild_id', 'primary_sku_id',
'slug', 'cover_image')
__slots__ = ( def __init__(self, state, data):
'_state', self._state = state
'description',
'id',
'name',
'rpc_origins',
'bot_public',
'bot_require_code_grant',
'owner',
'_icon',
'summary',
'verify_key',
'team',
'guild_id',
'primary_sku_id',
'slug',
'_cover_image',
'terms_of_service_url',
'privacy_policy_url',
)
def __init__(self, state: ConnectionState, data: AppInfoPayload): self.id = int(data['id'])
from .team import Team self.name = data['name']
self.description = data['description']
self.icon = data['icon']
self.rpc_origins = data['rpc_origins']
self.bot_public = data['bot_public']
self.bot_require_code_grant = data['bot_require_code_grant']
self.owner = User(state=self._state, data=data['owner'])
self._state: ConnectionState = state team = data.get('team')
self.id: int = int(data['id']) self.team = Team(state, team) if team else None
self.name: str = data['name']
self.description: str = data['description']
self._icon: Optional[str] = data['icon']
self.rpc_origins: List[str] = data['rpc_origins']
self.bot_public: bool = data['bot_public']
self.bot_require_code_grant: bool = data['bot_require_code_grant']
self.owner: User = state.create_user(data['owner'])
team: Optional[TeamPayload] = data.get('team') self.summary = data['summary']
self.team: Optional[Team] = Team(state, team) if team else None self.verify_key = data['verify_key']
self.summary: str = data['summary'] self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.verify_key: str = data['verify_key']
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') self.primary_sku_id = utils._get_as_snowflake(data, 'primary_sku_id')
self.slug = data.get('slug')
self.cover_image = data.get('cover_image')
self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') def __repr__(self):
self.slug: Optional[str] = data.get('slug') return '<{0.__class__.__name__} id={0.id} name={0.name!r} description={0.description!r} public={0.bot_public} ' \
self._cover_image: Optional[str] = data.get('cover_image') 'owner={0.owner!r}>'.format(self)
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
def __repr__(self) -> str:
return (
f'<{self.__class__.__name__} id={self.id} name={self.name!r} '
f'description={self.description!r} public={self.bot_public} '
f'owner={self.owner!r}>'
)
@property @property
def icon(self) -> Optional[Asset]: def icon_url(self):
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" """:class:`.Asset`: Retrieves the application's icon asset.
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')
@property This is equivalent to calling :meth:`icon_url_as` with
def cover_image(self) -> Optional[Asset]: the default parameters ('webp' format and a size of 1024).
"""Optional[:class:`.Asset`]: Retrieves the cover image on a store embed, if any.
This is only available if the application is a game sold on Discord. .. versionadded:: 1.3
""" """
if self._cover_image is None: return self.icon_url_as()
return None
return Asset._from_cover_image(self._state, self.id, self._cover_image) def icon_url_as(self, *, format='webp', size=1024):
"""Returns an :class:`Asset` for the icon the application has.
The format must be one of 'webp', 'jpeg', 'jpg' or 'png'.
The size must be a power of 2 between 16 and 4096.
.. versionadded:: 1.6
Parameters
-----------
format: :class:`str`
The format to attempt to convert the icon to. Defaults to 'webp'.
size: :class:`int`
The size of the image to display.
Raises
------
InvalidArgument
Bad image format passed to ``format`` or invalid ``size``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_icon(self._state, self, 'app', format=format, size=size)
@property @property
def guild(self) -> Optional[Guild]: def cover_image_url(self):
""":class:`.Asset`: Retrieves the cover image on a store embed.
This is equivalent to calling :meth:`cover_image_url_as` with
the default parameters ('webp' format and a size of 1024).
.. versionadded:: 1.3
"""
return self.cover_image_url_as()
def cover_image_url_as(self, *, format='webp', size=1024):
"""Returns an :class:`Asset` for the image on store embeds
if this application is a game sold on Discord.
The format must be one of 'webp', 'jpeg', 'jpg' or 'png'.
The size must be a power of 2 between 16 and 4096.
.. versionadded:: 1.6
Parameters
-----------
format: :class:`str`
The format to attempt to convert the image to. Defaults to 'webp'.
size: :class:`int`
The size of the image to display.
Raises
------
InvalidArgument
Bad image format passed to ``format`` or invalid ``size``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_cover_image(self._state, self, format=format, size=size)
@property
def guild(self):
"""Optional[:class:`Guild`]: If this application is a game sold on Discord, """Optional[:class:`Guild`]: If this application is a game sold on Discord,
this field will be the guild to which it has been linked this field will be the guild to which it has been linked
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
return self._state._get_guild(self.guild_id) return self._state._get_guild(int(self.guild_id))
class PartialAppInfo:
"""Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite`
.. versionadded:: 2.0
Attributes
-------------
id: :class:`int`
The application ID.
name: :class:`str`
The application name.
description: :class:`str`
The application description.
rpc_origins: Optional[List[:class:`str`]]
A list of RPC origin URLs, if RPC is enabled.
summary: :class:`str`
If this application is a game sold on Discord,
this field will be the summary field for the store page of its primary SKU.
verify_key: :class:`str`
The hex encoded key for verification in interactions and the
GameSDK's `GetTicket <https://discord.com/developers/docs/game-sdk/applications#getticket>`_.
terms_of_service_url: Optional[:class:`str`]
The application's terms of service URL, if set.
privacy_policy_url: Optional[:class:`str`]
The application's privacy policy URL, if set.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon')
def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload):
self._state: ConnectionState = state
self.id: int = int(data['id'])
self.name: str = data['name']
self._icon: Optional[str] = data.get('icon')
self.description: str = data['description']
self.rpc_origins: Optional[List[str]] = data.get('rpc_origins')
self.summary: str = data['summary']
self.verify_key: str = data['verify_key']
self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url')
self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url')
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>'
@property
def icon(self) -> Optional[Asset]:
"""Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any."""
if self._icon is None:
return None
return Asset._from_icon(self._state, self.id, self._icon, path='app')

View File

@@ -22,101 +22,19 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import io import io
import os
from typing import Any, Literal, Optional, TYPE_CHECKING, Tuple, Union
from .errors import DiscordException from .errors import DiscordException
from .errors import InvalidArgument from .errors import InvalidArgument
from . import utils from . import utils
import yarl
__all__ = ( __all__ = (
'Asset', 'Asset',
) )
if TYPE_CHECKING:
ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png']
ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif']
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} VALID_AVATAR_FORMATS = VALID_STATIC_FORMATS | {"gif"}
class Asset:
MISSING = utils.MISSING
class AssetMixin:
url: str
_state: Optional[Any]
async def read(self) -> bytes:
"""|coro|
Retrieves the content of this asset as a :class:`bytes` object.
Raises
------
DiscordException
There was no internal connection state.
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
Returns
-------
:class:`bytes`
The content of the asset.
"""
if self._state is None:
raise DiscordException('Invalid state (no ConnectionState provided)')
return await self._state.http.get_from_cdn(self.url)
async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int:
"""|coro|
Saves this asset into a file-like object.
Parameters
----------
fp: Union[:class:`io.BufferedIOBase`, :class:`os.PathLike`]
The file-like object to save this attachment to or the filename
to use. If a filename is passed then a file is created with that
filename and used instead.
seek_begin: :class:`bool`
Whether to seek to the beginning of the file after saving is
successfully done.
Raises
------
DiscordException
There was no internal connection state.
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
Returns
--------
:class:`int`
The number of bytes written.
"""
data = await self.read()
if isinstance(fp, io.BufferedIOBase):
written = fp.write(data)
if seek_begin:
fp.seek(0)
return written
else:
with open(fp, 'wb') as f:
return f.write(data)
class Asset(AssetMixin):
"""Represents a CDN asset on Discord. """Represents a CDN asset on Discord.
.. container:: operations .. container:: operations
@@ -129,6 +47,10 @@ class Asset(AssetMixin):
Returns the length of the CDN asset's URL. Returns the length of the CDN asset's URL.
.. describe:: bool(x)
Checks if the Asset has a URL.
.. describe:: x == y .. describe:: x == y
Checks if the asset is equal to another asset. Checks if the asset is equal to another asset.
@@ -141,275 +63,202 @@ class Asset(AssetMixin):
Returns the hash of the asset. Returns the hash of the asset.
""" """
__slots__ = ('_state', '_url')
__slots__: Tuple[str, ...] = (
'_state',
'_url',
'_animated',
'_key',
)
BASE = 'https://cdn.discordapp.com' BASE = 'https://cdn.discordapp.com'
def __init__(self, state, *, url: str, key: str, animated: bool = False): def __init__(self, state, url=None):
self._state = state self._state = state
self._url = url self._url = url
self._animated = animated
self._key = key
@classmethod @classmethod
def _from_default_avatar(cls, state, index: int) -> Asset: def _from_avatar(cls, state, user, *, format=None, static_format='webp', size=1024):
return cls( if not utils.valid_icon_size(size):
state, raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url=f'{cls.BASE}/embed/avatars/{index}.png', if format is not None and format not in VALID_AVATAR_FORMATS:
key=str(index), raise InvalidArgument(f"format must be None or one of {VALID_AVATAR_FORMATS}")
animated=False, if format == "gif" and not user.is_avatar_animated():
) raise InvalidArgument("non animated avatars do not support gif format")
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
if user.avatar is None:
return user.default_avatar_url
if format is None:
format = 'gif' if user.is_avatar_animated() else static_format
return cls(state, '/avatars/{0.id}/{0.avatar}.{1}?size={2}'.format(user, format, size))
@classmethod @classmethod
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: def _from_icon(cls, state, object, path, *, format='webp', size=1024):
animated = avatar.startswith('a_') if object.icon is None:
format = 'gif' if animated else 'png' return cls(state)
return cls(
state, if not utils.valid_icon_size(size):
url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024', raise InvalidArgument("size must be a power of 2 between 16 and 4096")
key=avatar, if format not in VALID_STATIC_FORMATS:
animated=animated, raise InvalidArgument(f"format must be None or one of {VALID_STATIC_FORMATS}")
)
url = '/{0}-icons/{1.id}/{1.icon}.{2}?size={3}'.format(path, object, format, size)
return cls(state, url)
@classmethod @classmethod
def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: def _from_cover_image(cls, state, obj, *, format='webp', size=1024):
animated = avatar.startswith('a_') if obj.cover_image is None:
format = 'gif' if animated else 'png' return cls(state)
return cls(
state, if not utils.valid_icon_size(size):
url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", raise InvalidArgument("size must be a power of 2 between 16 and 4096")
key=avatar, if format not in VALID_STATIC_FORMATS:
animated=animated, raise InvalidArgument(f"format must be None or one of {VALID_STATIC_FORMATS}")
)
url = '/app-assets/{0.id}/store/{0.cover_image}.{1}?size={2}'.format(obj, format, size)
return cls(state, url)
@classmethod @classmethod
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: def _from_guild_image(cls, state, id, hash, key, *, format='webp', size=1024):
return cls( if not utils.valid_icon_size(size):
state, raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', if format not in VALID_STATIC_FORMATS:
key=icon_hash, raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}")
animated=False,
) if hash is None:
return cls(state)
url = '/{key}/{0}/{1}.{2}?size={3}'
return cls(state, url.format(id, hash, format, size, key=key))
@classmethod @classmethod
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: def _from_guild_icon(cls, state, guild, *, format=None, static_format='webp', size=1024):
return cls( if not utils.valid_icon_size(size):
state, raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', if format is not None and format not in VALID_AVATAR_FORMATS:
key=cover_image_hash, raise InvalidArgument(f"format must be one of {VALID_AVATAR_FORMATS}")
animated=False, if format == "gif" and not guild.is_icon_animated():
) raise InvalidArgument("non animated guild icons do not support gif format")
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
if guild.icon is None:
return cls(state)
if format is None:
format = 'gif' if guild.is_icon_animated() else static_format
return cls(state, '/icons/{0.id}/{0.icon}.{1}?size={2}'.format(guild, format, size))
@classmethod @classmethod
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: def _from_sticker_url(cls, state, sticker, *, size=1024):
return cls( if not utils.valid_icon_size(size):
state, raise InvalidArgument("size must be a power of 2 between 16 and 4096")
url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024',
key=image, return cls(state, '/stickers/{0.id}/{0.image}.png?size={2}'.format(sticker, format, size))
animated=False,
)
@classmethod @classmethod
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: def _from_emoji(cls, state, emoji, *, format=None, static_format='png'):
animated = icon_hash.startswith('a_') if format is not None and format not in VALID_AVATAR_FORMATS:
format = 'gif' if animated else 'png' raise InvalidArgument(f"format must be None or one of {VALID_AVATAR_FORMATS}")
return cls( if format == "gif" and not emoji.animated:
state, raise InvalidArgument("non animated emoji's do not support gif format")
url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', if static_format not in VALID_STATIC_FORMATS:
key=icon_hash, raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}")
animated=animated, if format is None:
) format = 'gif' if emoji.animated else static_format
@classmethod return cls(state, f'/emojis/{emoji.id}.{format}')
def _from_sticker_banner(cls, state, banner: int) -> Asset:
return cls(
state,
url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png',
key=str(banner),
animated=False,
)
@classmethod def __str__(self):
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: return self.BASE + self._url if self._url is not None else ''
animated = banner_hash.startswith('a_')
format = 'gif' if animated else 'png'
return cls(
state,
url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512',
key=banner_hash,
animated=animated
)
def __str__(self) -> str: def __len__(self):
return self._url if self._url:
return len(self.BASE + self._url)
return 0
def __len__(self) -> int: def __bool__(self):
return len(self._url) return self._url is not None
def __repr__(self): def __repr__(self):
shorten = self._url.replace(self.BASE, '') return f'<Asset url={self._url!r}>'
return f'<Asset url={shorten!r}>'
def __eq__(self, other): def __eq__(self, other):
return isinstance(other, Asset) and self._url == other._url return isinstance(other, Asset) and self._url == other._url
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self): def __hash__(self):
return hash(self._url) return hash(self._url)
@property async def read(self):
def url(self) -> str: """|coro|
""":class:`str`: Returns the underlying URL of the asset."""
return self._url
@property Retrieves the content of this asset as a :class:`bytes` object.
def key(self) -> str:
""":class:`str`: Returns the identifying key of the asset."""
return self._key
def is_animated(self) -> bool: .. warning::
""":class:`bool`: Returns whether the asset is animated."""
return self._animated
def replace( :class:`PartialEmoji` won't have a connection state if user created,
self, and a URL won't be present if a custom image isn't associated with
*, the asset, e.g. a guild with no custom icon.
size: int = MISSING,
format: ValidAssetFormatTypes = MISSING,
static_format: ValidStaticFormatTypes = MISSING,
) -> Asset:
"""Returns a new asset with the passed components replaced.
Parameters .. versionadded:: 1.1
-----------
size: :class:`int`
The new size of the asset.
format: :class:`str`
The new format to change it to. Must be either
'webp', 'jpeg', 'jpg', 'png', or 'gif' if it's animated.
static_format: :class:`str`
The new format to change it to if the asset isn't animated.
Must be either 'webp', 'jpeg', 'jpg', or 'png'.
Raises Raises
------
DiscordException
There was no valid URL or internal connection state.
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
Returns
------- -------
InvalidArgument :class:`bytes`
An invalid size or format was passed. The content of the asset.
"""
if not self._url:
raise DiscordException('Invalid asset (no URL provided)')
if self._state is None:
raise DiscordException('Invalid state (no ConnectionState provided)')
return await self._state.http.get_from_cdn(self.BASE + self._url)
async def save(self, fp, *, seek_begin=True):
"""|coro|
Saves this asset into a file-like object.
Parameters
----------
fp: Union[BinaryIO, :class:`os.PathLike`]
Same as in :meth:`Attachment.save`.
seek_begin: :class:`bool`
Same as in :meth:`Attachment.save`.
Raises
------
DiscordException
There was no valid URL or internal connection state.
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
Returns Returns
-------- --------
:class:`Asset` :class:`int`
The newly updated asset. The number of bytes written.
""" """
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
if format is not MISSING: data = await self.read()
if self._animated: if isinstance(fp, io.IOBase) and fp.writable():
if format not in VALID_ASSET_FORMATS: written = fp.write(data)
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') if seek_begin:
else: fp.seek(0)
if format not in VALID_STATIC_FORMATS: return written
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{format}')
if static_format is not MISSING and not self._animated:
if static_format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}')
url = url.with_path(f'{path}.{static_format}')
if size is not MISSING:
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
url = url.with_query(size=size)
else: else:
url = url.with_query(url.raw_query_string) with open(fp, 'wb') as f:
return f.write(data)
url = str(url)
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_size(self, size: int, /) -> Asset:
"""Returns a new asset with the specified size.
Parameters
------------
size: :class:`int`
The new size of the asset.
Raises
-------
InvalidArgument
The asset had an invalid size.
Returns
--------
:class:`Asset`
The new updated asset.
"""
if not utils.valid_icon_size(size):
raise InvalidArgument('size must be a power of 2 between 16 and 4096')
url = str(yarl.URL(self._url).with_query(size=size))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_format(self, format: ValidAssetFormatTypes, /) -> Asset:
"""Returns a new asset with the specified format.
Parameters
------------
format: :class:`str`
The new format of the asset.
Raises
-------
InvalidArgument
The asset had an invalid format.
Returns
--------
:class:`Asset`
The new updated asset.
"""
if self._animated:
if format not in VALID_ASSET_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}')
else:
if format not in VALID_STATIC_FORMATS:
raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}')
url = yarl.URL(self._url)
path, _ = os.path.splitext(url.path)
url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string))
return Asset(state=self._state, url=url, key=self._key, animated=self._animated)
def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset:
"""Returns a new asset with the specified static format.
This only changes the format if the underlying asset is
not animated. Otherwise, the asset is not changed.
Parameters
------------
format: :class:`str`
The new static format of the asset.
Raises
-------
InvalidArgument
The asset had an invalid format.
Returns
--------
:class:`Asset`
The new updated asset.
"""
if self._animated:
return self
return self.with_format(format)

View File

@@ -22,17 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from . import utils, enums
from .object import Object
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union from .permissions import PermissionOverwrite, Permissions
from . import enums, utils
from .asset import Asset
from .colour import Colour from .colour import Colour
from .invite import Invite from .invite import Invite
from .mixins import Hashable from .mixins import Hashable
from .object import Object
from .permissions import PermissionOverwrite, Permissions
__all__ = ( __all__ = (
'AuditLogDiff', 'AuditLogDiff',
@@ -40,72 +35,51 @@ __all__ = (
'AuditLogEntry', 'AuditLogEntry',
) )
def _transform_verification_level(entry, data):
return enums.try_enum(enums.VerificationLevel, data)
if TYPE_CHECKING: def _transform_default_notifications(entry, data):
import datetime return enums.try_enum(enums.NotificationLevel, data)
from . import abc def _transform_explicit_content_filter(entry, data):
from .emoji import Emoji return enums.try_enum(enums.ContentFilter, data)
from .guild import Guild
from .member import Member
from .role import Role
from .types.audit_log import (
AuditLogChange as AuditLogChangePayload,
AuditLogEntry as AuditLogEntryPayload,
)
from .types.channel import PermissionOverwrite as PermissionOverwritePayload
from .types.role import Role as RolePayload
from .types.snowflake import Snowflake
from .user import User
from .stage_instance import StageInstance
from .sticker import GuildSticker
from .threads import Thread
def _transform_permissions(entry, data):
return Permissions(data)
def _transform_permissions(entry: AuditLogEntry, data: str) -> Permissions: def _transform_color(entry, data):
return Permissions(int(data))
def _transform_color(entry: AuditLogEntry, data: int) -> Colour:
return Colour(data) return Colour(data)
def _transform_snowflake(entry, data):
def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int:
return int(data) return int(data)
def _transform_channel(entry, data):
def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]:
if data is None: if data is None:
return None return None
return entry.guild.get_channel(int(data)) or Object(id=data) return entry.guild.get_channel(int(data)) or Object(id=data)
def _transform_owner_id(entry, data):
def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]:
if data is None: if data is None:
return None return None
return entry._get_member(int(data)) return entry._get_member(int(data))
def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: def _transform_inviter_id(entry, data):
if data is None: if data is None:
return None return None
return entry._state._get_guild(data) return entry._get_member(int(data))
def _transform_overwrites(entry, data):
def _transform_overwrites(
entry: AuditLogEntry, data: List[PermissionOverwritePayload]
) -> List[Tuple[Object, PermissionOverwrite]]:
overwrites = [] overwrites = []
for elem in data: for elem in data:
allow = Permissions(int(elem['allow'])) allow = Permissions(elem['allow'])
deny = Permissions(int(elem['deny'])) deny = Permissions(elem['deny'])
ow = PermissionOverwrite.from_pair(allow, deny) ow = PermissionOverwrite.from_pair(allow, deny)
ow_type = elem['type'] ow_type = elem['type']
ow_id = int(elem['id']) ow_id = int(elem['id'])
target = None if ow_type == 'role':
if ow_type == '0':
target = entry.guild.get_role(ow_id) target = entry.guild.get_role(ow_id)
elif ow_type == '1': else:
target = entry._get_member(ow_id) target = entry._get_member(ow_id)
if target is None: if target is None:
@@ -115,104 +89,41 @@ def _transform_overwrites(
return overwrites return overwrites
def _transform_icon(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]:
if data is None:
return None
return Asset._from_guild_icon(entry._state, entry.guild.id, data)
def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]:
if data is None:
return None
return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore
def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]:
def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]:
if data is None:
return None
return Asset._from_guild_image(entry._state, entry.guild.id, data, path=path)
return _transform
T = TypeVar('T', bound=enums.Enum)
def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]:
def _transform(entry: AuditLogEntry, data: int) -> T:
return enums.try_enum(enum, data)
return _transform
def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]:
if entry.action.name.startswith('sticker_'):
return enums.try_enum(enums.StickerType, data)
else:
return enums.try_enum(enums.ChannelType, data)
class AuditLogDiff: class AuditLogDiff:
def __len__(self) -> int: def __len__(self):
return len(self.__dict__) return len(self.__dict__)
def __iter__(self) -> Generator[Tuple[str, Any], None, None]: def __iter__(self):
yield from self.__dict__.items() return iter(self.__dict__.items())
def __repr__(self) -> str: def __repr__(self):
values = ' '.join('%s=%r' % item for item in self.__dict__.items()) values = ' '.join('%s=%r' % item for item in self.__dict__.items())
return f'<AuditLogDiff {values}>' return f'<AuditLogDiff {values}>'
if TYPE_CHECKING:
def __getattr__(self, item: str) -> Any:
...
def __setattr__(self, key: str, value: Any) -> Any:
...
Transformer = Callable[["AuditLogEntry", Any], Any]
class AuditLogChanges: class AuditLogChanges:
# fmt: off TRANSFORMERS = {
TRANSFORMERS: ClassVar[Dict[str, Tuple[Optional[str], Optional[Transformer]]]] = { 'verification_level': (None, _transform_verification_level),
'verification_level': (None, _enum_transformer(enums.VerificationLevel)), 'explicit_content_filter': (None, _transform_explicit_content_filter),
'explicit_content_filter': (None, _enum_transformer(enums.ContentFilter)),
'allow': (None, _transform_permissions), 'allow': (None, _transform_permissions),
'deny': (None, _transform_permissions), 'deny': (None, _transform_permissions),
'permissions': (None, _transform_permissions), 'permissions': (None, _transform_permissions),
'id': (None, _transform_snowflake), 'id': (None, _transform_snowflake),
'color': ('colour', _transform_color), 'color': ('colour', _transform_color),
'owner_id': ('owner', _transform_member_id), 'owner_id': ('owner', _transform_owner_id),
'inviter_id': ('inviter', _transform_member_id), 'inviter_id': ('inviter', _transform_inviter_id),
'channel_id': ('channel', _transform_channel), 'channel_id': ('channel', _transform_channel),
'afk_channel_id': ('afk_channel', _transform_channel), 'afk_channel_id': ('afk_channel', _transform_channel),
'system_channel_id': ('system_channel', _transform_channel), 'system_channel_id': ('system_channel', _transform_channel),
'widget_channel_id': ('widget_channel', _transform_channel), 'widget_channel_id': ('widget_channel', _transform_channel),
'rules_channel_id': ('rules_channel', _transform_channel),
'public_updates_channel_id': ('public_updates_channel', _transform_channel),
'permission_overwrites': ('overwrites', _transform_overwrites), 'permission_overwrites': ('overwrites', _transform_overwrites),
'splash_hash': ('splash', _guild_hash_transformer('splashes')), 'splash_hash': ('splash', None),
'banner_hash': ('banner', _guild_hash_transformer('banners')), 'icon_hash': ('icon', None),
'discovery_splash_hash': ('discovery_splash', _guild_hash_transformer('discovery-splashes')), 'avatar_hash': ('avatar', None),
'icon_hash': ('icon', _transform_icon),
'avatar_hash': ('avatar', _transform_avatar),
'rate_limit_per_user': ('slowmode_delay', None), 'rate_limit_per_user': ('slowmode_delay', None),
'guild_id': ('guild', _transform_guild_id), 'default_message_notifications': ('default_notifications', _transform_default_notifications),
'tags': ('emoji', None),
'default_message_notifications': ('default_notifications', _enum_transformer(enums.NotificationLevel)),
'region': (None, _enum_transformer(enums.VoiceRegion)),
'rtc_region': (None, _enum_transformer(enums.VoiceRegion)),
'video_quality_mode': (None, _enum_transformer(enums.VideoQualityMode)),
'privacy_level': (None, _enum_transformer(enums.StagePrivacyLevel)),
'format_type': (None, _enum_transformer(enums.StickerFormatType)),
'type': (None, _transform_type),
} }
# fmt: on
def __init__(self, entry: AuditLogEntry, data: List[AuditLogChangePayload]): def __init__(self, entry, data):
self.before = AuditLogDiff() self.before = AuditLogDiff()
self.after = AuditLogDiff() self.after = AuditLogDiff()
@@ -221,22 +132,18 @@ class AuditLogChanges:
# special cases for role add/remove # special cases for role add/remove
if attr == '$add': if attr == '$add':
self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore self._handle_role(self.before, self.after, entry, elem['new_value'])
continue continue
elif attr == '$remove': elif attr == '$remove':
self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore self._handle_role(self.after, self.before, entry, elem['new_value'])
continue continue
try: transformer = self.TRANSFORMERS.get(attr)
key, transformer = self.TRANSFORMERS[attr] if transformer:
except (ValueError, KeyError): key, transformer = transformer
transformer = None
else:
if key: if key:
attr = key attr = key
transformer: Optional[Transformer]
try: try:
before = elem['old_value'] before = elem['old_value']
except KeyError: except KeyError:
@@ -261,19 +168,16 @@ class AuditLogChanges:
if hasattr(self.after, 'colour'): if hasattr(self.after, 'colour'):
self.after.color = self.after.colour self.after.color = self.after.colour
self.before.color = self.before.colour self.before.color = self.before.colour
if hasattr(self.after, 'expire_behavior'):
self.after.expire_behaviour = self.after.expire_behavior
self.before.expire_behaviour = self.before.expire_behavior
def __repr__(self) -> str: def __repr__(self):
return f'<AuditLogChanges before={self.before!r} after={self.after!r}>' return f'<AuditLogChanges before={self.before!r} after={self.after!r}>'
def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: def _handle_role(self, first, second, entry, elem):
if not hasattr(first, 'roles'): if not hasattr(first, 'roles'):
setattr(first, 'roles', []) setattr(first, 'roles', [])
data = [] data = []
g: Guild = entry.guild # type: ignore g = entry.guild
for e in elem: for e in elem:
role_id = int(e['id']) role_id = int(e['id'])
@@ -281,36 +185,12 @@ class AuditLogChanges:
if role is None: if role is None:
role = Object(id=role_id) role = Object(id=role_id)
role.name = e['name'] # type: ignore role.name = e['name']
data.append(role) data.append(role)
setattr(second, 'roles', data) setattr(second, 'roles', data)
class _AuditLogProxyMemberPrune:
delete_member_days: int
members_removed: int
class _AuditLogProxyMemberMoveOrMessageDelete:
channel: abc.GuildChannel
count: int
class _AuditLogProxyMemberDisconnect:
count: int
class _AuditLogProxyPinAction:
channel: abc.GuildChannel
message_id: int
class _AuditLogProxyStageInstanceAction:
channel: abc.GuildChannel
class AuditLogEntry(Hashable): class AuditLogEntry(Hashable):
r"""Represents an Audit Log entry. r"""Represents an Audit Log entry.
@@ -354,13 +234,13 @@ class AuditLogEntry(Hashable):
which actions have this field filled out. which actions have this field filled out.
""" """
def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): def __init__(self, *, users, data, guild):
self._state = guild._state self._state = guild._state
self.guild = guild self.guild = guild
self._users = users self._users = users
self._from_data(data) self._from_data(data)
def _from_data(self, data: AuditLogEntryPayload) -> None: def _from_data(self, data):
self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) self.action = enums.try_enum(enums.AuditLogAction, data['action_type'])
self.id = int(data['id']) self.id = int(data['id'])
@@ -371,58 +251,41 @@ class AuditLogEntry(Hashable):
if isinstance(self.action, enums.AuditLogAction) and self.extra: if isinstance(self.action, enums.AuditLogAction) and self.extra:
if self.action is enums.AuditLogAction.member_prune: if self.action is enums.AuditLogAction.member_prune:
# member prune has two keys with useful information # member prune has two keys with useful information
self.extra: _AuditLogProxyMemberPrune = type( self.extra = type('_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()})()
'_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()}
)()
elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete:
channel_id = int(self.extra['channel_id']) channel_id = int(self.extra['channel_id'])
elems = { elems = {
'count': int(self.extra['count']), 'count': int(self.extra['count']),
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)
} }
self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)() self.extra = type('_AuditLogProxy', (), elems)()
elif self.action is enums.AuditLogAction.member_disconnect: elif self.action is enums.AuditLogAction.member_disconnect:
# The member disconnect action has a dict with some information # The member disconnect action has a dict with some information
elems = { elems = {
'count': int(self.extra['count']), 'count': int(self.extra['count']),
} }
self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)() self.extra = type('_AuditLogProxy', (), elems)()
elif self.action.name.endswith('pin'): elif self.action.name.endswith('pin'):
# the pin actions have a dict with some information # the pin actions have a dict with some information
channel_id = int(self.extra['channel_id']) channel_id = int(self.extra['channel_id'])
message_id = int(self.extra['message_id'])
elems = { elems = {
'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id),
'message_id': int(self.extra['message_id']), 'message_id': message_id
} }
self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)() self.extra = type('_AuditLogProxy', (), elems)()
elif self.action.name.startswith('overwrite_'): elif self.action.name.startswith('overwrite_'):
# the overwrite_ actions have a dict with some information # the overwrite_ actions have a dict with some information
instance_id = int(self.extra['id']) instance_id = int(self.extra['id'])
the_type = self.extra.get('type') the_type = self.extra.get('type')
if the_type == '1': if the_type == 'member':
self.extra = self._get_member(instance_id) self.extra = self._get_member(instance_id)
elif the_type == '0': else:
role = self.guild.get_role(instance_id) role = self.guild.get_role(instance_id)
if role is None: if role is None:
role = Object(id=instance_id) role = Object(id=instance_id)
role.name = self.extra.get('role_name') # type: ignore role.name = self.extra.get('role_name')
self.extra: Role = role self.extra = role
elif self.action.name.startswith('stage_instance'):
channel_id = int(self.extra['channel_id'])
elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)}
self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)()
# fmt: off
self.extra: Union[
_AuditLogProxyMemberPrune,
_AuditLogProxyMemberMoveOrMessageDelete,
_AuditLogProxyMemberDisconnect,
_AuditLogProxyPinAction,
_AuditLogProxyStageInstanceAction,
Member, User, None,
Role,
]
# fmt: on
# this key is not present when the above is present, typically. # this key is not present when the above is present, typically.
# It's a list of { new_value: a, old_value: b, key: c } # It's a list of { new_value: a, old_value: b, key: c }
@@ -431,22 +294,22 @@ class AuditLogEntry(Hashable):
# into meaningful data when requested # into meaningful data when requested
self._changes = data.get('changes', []) self._changes = data.get('changes', [])
self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore self.user = self._get_member(utils._get_as_snowflake(data, 'user_id'))
self._target_id = utils._get_as_snowflake(data, 'target_id') self._target_id = utils._get_as_snowflake(data, 'target_id')
def _get_member(self, user_id: int) -> Union[Member, User, None]: def _get_member(self, user_id):
return self.guild.get_member(user_id) or self._users.get(user_id) return self.guild.get_member(user_id) or self._users.get(user_id)
def __repr__(self) -> str: def __repr__(self):
return f'<AuditLogEntry id={self.id} action={self.action} user={self.user!r}>' return '<AuditLogEntry id={0.id} action={0.action} user={0.user!r}>'.format(self)
@utils.cached_property @utils.cached_property
def created_at(self) -> datetime.datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the entry's creation time in UTC.""" """:class:`datetime.datetime`: Returns the entry's creation time in UTC."""
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
@utils.cached_property @utils.cached_property
def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: def target(self):
try: try:
converter = getattr(self, '_convert_target_' + self.action.target_type) converter = getattr(self, '_convert_target_' + self.action.target_type)
except AttributeError: except AttributeError:
@@ -455,40 +318,46 @@ class AuditLogEntry(Hashable):
return converter(self._target_id) return converter(self._target_id)
@utils.cached_property @utils.cached_property
def category(self) -> enums.AuditLogActionCategory: def category(self):
"""Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable.""" """Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable."""
return self.action.category return self.action.category
@utils.cached_property @utils.cached_property
def changes(self) -> AuditLogChanges: def changes(self):
""":class:`AuditLogChanges`: The list of changes this entry has.""" """:class:`AuditLogChanges`: The list of changes this entry has."""
obj = AuditLogChanges(self, self._changes) obj = AuditLogChanges(self, self._changes)
del self._changes del self._changes
return obj return obj
@utils.cached_property @utils.cached_property
def before(self) -> AuditLogDiff: def before(self):
""":class:`AuditLogDiff`: The target's prior state.""" """:class:`AuditLogDiff`: The target's prior state."""
return self.changes.before return self.changes.before
@utils.cached_property @utils.cached_property
def after(self) -> AuditLogDiff: def after(self):
""":class:`AuditLogDiff`: The target's subsequent state.""" """:class:`AuditLogDiff`: The target's subsequent state."""
return self.changes.after return self.changes.after
def _convert_target_guild(self, target_id: int) -> Guild: def _convert_target_guild(self, target_id):
return self.guild return self.guild
def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]: def _convert_target_channel(self, target_id):
return self.guild.get_channel(target_id) or Object(id=target_id) ch = self.guild.get_channel(target_id)
if ch is None:
return Object(id=target_id)
return ch
def _convert_target_user(self, target_id: int) -> Union[Member, User, None]: def _convert_target_user(self, target_id):
return self._get_member(target_id) return self._get_member(target_id)
def _convert_target_role(self, target_id: int) -> Union[Role, Object]: def _convert_target_role(self, target_id):
return self.guild.get_role(target_id) or Object(id=target_id) role = self.guild.get_role(target_id)
if role is None:
return Object(id=target_id)
return role
def _convert_target_invite(self, target_id: int) -> Invite: def _convert_target_invite(self, target_id):
# invites have target_id set to null # invites have target_id set to null
# so figure out which change has the full invite data # so figure out which change has the full invite data
changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after
@@ -498,27 +367,20 @@ class AuditLogEntry(Hashable):
'max_uses': changeset.max_uses, 'max_uses': changeset.max_uses,
'code': changeset.code, 'code': changeset.code,
'temporary': changeset.temporary, 'temporary': changeset.temporary,
'channel': changeset.channel,
'uses': changeset.uses, 'uses': changeset.uses,
'guild': self.guild,
} }
obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore obj = Invite(state=self._state, data=fake_payload)
try: try:
obj.inviter = changeset.inviter obj.inviter = changeset.inviter
except AttributeError: except AttributeError:
pass pass
return obj return obj
def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: def _convert_target_emoji(self, target_id):
return self._state.get_emoji(target_id) or Object(id=target_id) return self._state.get_emoji(target_id) or Object(id=target_id)
def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: def _convert_target_message(self, target_id):
return self._get_member(target_id) return self._get_member(target_id)
def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]:
return self.guild.get_stage_instance(target_id) or Object(id=target_id)
def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]:
return self._state.get_sticker(target_id) or Object(id=target_id)
def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]:
return self.guild.get_thread(target_id) or Object(id=target_id)

View File

@@ -22,20 +22,14 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import time import time
import random import random
from typing import Callable, Generic, Literal, TypeVar, overload, Union
T = TypeVar('T', bool, Literal[True], Literal[False])
__all__ = ( __all__ = (
'ExponentialBackoff', 'ExponentialBackoff',
) )
class ExponentialBackoff(Generic[T]): class ExponentialBackoff:
"""An implementation of the exponential backoff algorithm """An implementation of the exponential backoff algorithm
Provides a convenient interface to implement an exponential backoff Provides a convenient interface to implement an exponential backoff
@@ -57,33 +51,21 @@ class ExponentialBackoff(Generic[T]):
number in between may be returned. number in between may be returned.
""" """
def __init__(self, base: int = 1, *, integral: T = False): def __init__(self, base=1, *, integral=False):
self._base: int = base self._base = base
self._exp: int = 0 self._exp = 0
self._max: int = 10 self._max = 10
self._reset_time: int = base * 2 ** 11 self._reset_time = base * 2 ** 11
self._last_invocation: float = time.monotonic() self._last_invocation = time.monotonic()
# Use our own random instance to avoid messing with global one # Use our own random instance to avoid messing with global one
rand = random.Random() rand = random.Random()
rand.seed() rand.seed()
self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore self._randfunc = rand.randrange if integral else rand.uniform
@overload def delay(self):
def delay(self: ExponentialBackoff[Literal[False]]) -> float:
...
@overload
def delay(self: ExponentialBackoff[Literal[True]]) -> int:
...
@overload
def delay(self: ExponentialBackoff[bool]) -> Union[int, float]:
...
def delay(self) -> Union[int, float]:
"""Compute the next delay """Compute the next delay
Returns the next delay to wait according to the exponential Returns the next delay to wait according to the exponential

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,383 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .enums import try_enum, ComponentType, ButtonStyle
from .utils import get_slots, MISSING
from .partial_emoji import PartialEmoji, _EmojiTag
if TYPE_CHECKING:
from .types.components import (
Component as ComponentPayload,
ButtonComponent as ButtonComponentPayload,
SelectMenu as SelectMenuPayload,
SelectOption as SelectOptionPayload,
ActionRow as ActionRowPayload,
)
from .emoji import Emoji
__all__ = (
'Component',
'ActionRow',
'Button',
'SelectMenu',
'SelectOption',
)
C = TypeVar('C', bound='Component')
class Component:
"""Represents a Discord Bot UI Kit Component.
Currently, the only components supported by Discord are:
- :class:`ActionRow`
- :class:`Button`
- :class:`SelectMenu`
This class is abstract and cannot be instantiated.
.. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
"""
__slots__: Tuple[str, ...] = ('type',)
__repr_info__: ClassVar[Tuple[str, ...]]
type: ComponentType
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__)
return f'<{self.__class__.__name__} {attrs}>'
@classmethod
def _raw_construct(cls: Type[C], **kwargs) -> C:
self: C = cls.__new__(cls)
for slot in get_slots(cls):
try:
value = kwargs[slot]
except KeyError:
pass
else:
setattr(self, slot, value)
return self
def to_dict(self) -> Dict[str, Any]:
raise NotImplementedError
class ActionRow(Component):
"""Represents a Discord Bot UI Kit Action Row.
This is a component that holds up to 5 children components in a row.
This inherits from :class:`Component`.
.. versionadded:: 2.0
Attributes
------------
type: :class:`ComponentType`
The type of component.
children: List[:class:`Component`]
The children components that this holds, if any.
"""
__slots__: Tuple[str, ...] = ('children',)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])]
def to_dict(self) -> ActionRowPayload:
return {
'type': int(self.type),
'components': [child.to_dict() for child in self.children],
} # type: ignore
class Button(Component):
"""Represents a button from the Discord Bot UI Kit.
This inherits from :class:`Component`.
.. note::
The user constructible and usable type to create a button is :class:`discord.ui.Button`
not this one.
.. versionadded:: 2.0
Attributes
-----------
style: :class:`.ButtonStyle`
The style of the button.
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID.
url: Optional[:class:`str`]
The URL this button sends you to.
disabled: :class:`bool`
Whether the button is disabled or not.
label: Optional[:class:`str`]
The label of the button, if any.
emoji: Optional[:class:`PartialEmoji`]
The emoji of the button, if available.
"""
__slots__: Tuple[str, ...] = (
'style',
'custom_id',
'url',
'disabled',
'label',
'emoji',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: ButtonComponentPayload):
self.type: ComponentType = try_enum(ComponentType, data['type'])
self.style: ButtonStyle = try_enum(ButtonStyle, data['style'])
self.custom_id: Optional[str] = data.get('custom_id')
self.url: Optional[str] = data.get('url')
self.disabled: bool = data.get('disabled', False)
self.label: Optional[str] = data.get('label')
self.emoji: Optional[PartialEmoji]
try:
self.emoji = PartialEmoji.from_dict(data['emoji'])
except KeyError:
self.emoji = None
def to_dict(self) -> ButtonComponentPayload:
payload = {
'type': 2,
'style': int(self.style),
'label': self.label,
'disabled': self.disabled,
}
if self.custom_id:
payload['custom_id'] = self.custom_id
if self.url:
payload['url'] = self.url
if self.emoji:
payload['emoji'] = self.emoji.to_dict()
return payload # type: ignore
class SelectMenu(Component):
"""Represents a select menu from the Discord Bot UI Kit.
A select menu is functionally the same as a dropdown, however
on mobile it renders a bit differently.
.. note::
The user constructible and usable type to create a select menu is
:class:`discord.ui.Select` not this one.
.. versionadded:: 2.0
Attributes
------------
custom_id: Optional[:class:`str`]
The ID of the select menu that gets received during an interaction.
placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any.
min_values: :class:`int`
The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
max_values: :class:`int`
The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
options: List[:class:`SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
"""
__slots__: Tuple[str, ...] = (
'custom_id',
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
)
__repr_info__: ClassVar[Tuple[str, ...]] = __slots__
def __init__(self, data: SelectMenuPayload):
self.type = ComponentType.select
self.custom_id: str = data['custom_id']
self.placeholder: Optional[str] = data.get('placeholder')
self.min_values: int = data.get('min_values', 1)
self.max_values: int = data.get('max_values', 1)
self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])]
self.disabled: bool = data.get('disabled', False)
def to_dict(self) -> SelectMenuPayload:
payload: SelectMenuPayload = {
'type': self.type.value,
'custom_id': self.custom_id,
'min_values': self.min_values,
'max_values': self.max_values,
'options': [op.to_dict() for op in self.options],
'disabled': self.disabled,
}
if self.placeholder:
payload['placeholder'] = self.placeholder
return payload
class SelectOption:
"""Represents a select menu's option.
These can be created by users.
.. versionadded:: 2.0
Attributes
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not provided when constructed then it defaults to the
label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`Emoji`, :class:`PartialEmoji`]]
The emoji of the option, if available.
default: :class:`bool`
Whether this option is selected by default.
"""
__slots__: Tuple[str, ...] = (
'label',
'value',
'description',
'emoji',
'default',
)
def __init__(
self,
*,
label: str,
value: str = MISSING,
description: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
) -> None:
self.label = label
self.value = label if value is MISSING else value
self.description = description
if emoji is not None:
if isinstance(emoji, str):
emoji = PartialEmoji.from_str(emoji)
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self.emoji = emoji
self.default = default
def __repr__(self) -> str:
return (
f'<SelectOption label={self.label!r} value={self.value!r} description={self.description!r} '
f'emoji={self.emoji!r} default={self.default!r}>'
)
def __str__(self) -> str:
if self.emoji:
base = f'{self.emoji} {self.label}'
else:
base = self.label
if self.description:
return f'{base}\n{self.description}'
return base
@classmethod
def from_dict(cls, data: SelectOptionPayload) -> SelectOption:
try:
emoji = PartialEmoji.from_dict(data['emoji'])
except KeyError:
emoji = None
return cls(
label=data['label'],
value=data['value'],
description=data.get('description'),
emoji=emoji,
default=data.get('default', False),
)
def to_dict(self) -> SelectOptionPayload:
payload: SelectOptionPayload = {
'label': self.label,
'value': self.value,
'default': self.default,
}
if self.emoji:
payload['emoji'] = self.emoji.to_dict() # type: ignore
if self.description:
payload['description'] = self.description
return payload
def _component_factory(data: ComponentPayload) -> Component:
component_type = data['type']
if component_type == 1:
return ActionRow(data)
elif component_type == 2:
return Button(data) # type: ignore
elif component_type == 3:
return SelectMenu(data) # type: ignore
else:
as_enum = try_enum(ComponentType, component_type)
return Component._raw_construct(type=as_enum)

View File

@@ -22,23 +22,13 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
from typing import TYPE_CHECKING, TypeVar, Optional, Type
if TYPE_CHECKING:
from .abc import Messageable
from types import TracebackType
TypingT = TypeVar('TypingT', bound='Typing')
__all__ = ( __all__ = (
'Typing', 'Typing',
) )
def _typing_done_callback(fut: asyncio.Future) -> None: def _typing_done_callback(fut):
# just retrieve any exception and call it a day # just retrieve any exception and call it a day
try: try:
fut.exception() fut.exception()
@@ -46,11 +36,11 @@ def _typing_done_callback(fut: asyncio.Future) -> None:
pass pass
class Typing: class Typing:
def __init__(self, messageable: Messageable) -> None: def __init__(self, messageable):
self.loop: asyncio.AbstractEventLoop = messageable._state.loop self.loop = messageable._state.loop
self.messageable: Messageable = messageable self.messageable = messageable
async def do_typing(self) -> None: async def do_typing(self):
try: try:
channel = self._channel channel = self._channel
except AttributeError: except AttributeError:
@@ -62,26 +52,18 @@ class Typing:
await typing(channel.id) await typing(channel.id)
await asyncio.sleep(5) await asyncio.sleep(5)
def __enter__(self: TypingT) -> TypingT: def __enter__(self):
self.task: asyncio.Task = self.loop.create_task(self.do_typing()) self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop)
self.task.add_done_callback(_typing_done_callback) self.task.add_done_callback(_typing_done_callback)
return self return self
def __exit__(self, def __exit__(self, exc_type, exc, tb):
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel() self.task.cancel()
async def __aenter__(self: TypingT) -> TypingT: async def __aenter__(self):
self._channel = channel = await self.messageable._get_channel() self._channel = channel = await self.messageable._get_channel()
await channel._state.http.send_typing(channel.id) await channel._state.http.send_typing(channel.id)
return self.__enter__() return self.__enter__()
async def __aexit__(self, async def __aexit__(self, exc_type, exc, tb):
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.task.cancel() self.task.cancel()

View File

@@ -25,7 +25,8 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union import os
from typing import Any, Dict, Final, List, Protocol, TYPE_CHECKING, Type, TypeVar, Union
from . import utils from . import utils
from .colour import Colour from .colour import Colour
@@ -179,14 +180,20 @@ class Embed:
*, *,
colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed,
title: MaybeEmpty[Any] = EmptyEmbed, title: MaybeEmpty[str] = EmptyEmbed,
type: EmbedType = 'rich', type: EmbedType = 'rich',
url: MaybeEmpty[Any] = EmptyEmbed, url: MaybeEmpty[str] = EmptyEmbed,
description: MaybeEmpty[Any] = EmptyEmbed, description: MaybeEmpty[str] = EmptyEmbed,
timestamp: datetime.datetime = None, timestamp: datetime.datetime = None,
): ):
self.colour = colour if colour is not EmptyEmbed else color if colour is EmptyEmbed and color is EmptyEmbed:
colour = os.getenv("DEFAULT_EMBED_COLOR", default=EmptyEmbed)
if isinstance(colour, str):
colour = int(colour, 16)
else:
colour = colour if colour is not EmptyEmbed else color
self.colour = colour
self.title = title self.title = title
self.type = type self.type = type
self.url = url self.url = url
@@ -202,10 +209,12 @@ class Embed:
self.url = str(self.url) self.url = str(self.url)
if timestamp: if timestamp:
if timestamp.tzinfo is None:
timestamp = timestamp.astimezone()
self.timestamp = timestamp self.timestamp = timestamp
@classmethod @classmethod
def from_dict(cls: Type[E], data: Mapping[str, Any]) -> E: def from_dict(cls: Type[E], data: EmbedData) -> E:
"""Converts a :class:`dict` to a :class:`Embed` provided it is in the """Converts a :class:`dict` to a :class:`Embed` provided it is in the
format that Discord expects it to be in. format that Discord expects it to be in.
@@ -271,11 +280,11 @@ class Embed:
total += len(field['name']) + len(field['value']) total += len(field['name']) + len(field['value'])
try: try:
footer_text = self._footer['text'] footer = self._footer
except (AttributeError, KeyError): except AttributeError:
pass pass
else: else:
total += len(footer_text) total += len(footer['text'])
try: try:
author = self._author author = self._author
@@ -325,11 +334,7 @@ class Embed:
@timestamp.setter @timestamp.setter
def timestamp(self, value: MaybeEmpty[datetime.datetime]): def timestamp(self, value: MaybeEmpty[datetime.datetime]):
if isinstance(value, datetime.datetime): if isinstance(value, (datetime.datetime, _EmptyEmbed)):
if value.tzinfo is None:
value = value.astimezone()
self._timestamp = value
elif isinstance(value, _EmptyEmbed):
self._timestamp = value self._timestamp = value
else: else:
raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead") raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead")
@@ -344,7 +349,7 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_footer', {})) # type: ignore return EmbedProxy(getattr(self, '_footer', {})) # type: ignore
def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: def set_footer(self: E, *, text: MaybeEmpty[str] = EmptyEmbed, icon_url: MaybeEmpty[str] = EmptyEmbed) -> E:
"""Sets the footer for the embed content. """Sets the footer for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -366,22 +371,7 @@ class Embed:
self._footer['icon_url'] = str(icon_url) self._footer['icon_url'] = str(icon_url)
return self return self
def remove_footer(self: E) -> E:
"""Clears embed's footer information.
This function returns the class instance to allow for fluent-style
chaining.
.. versionadded:: 2.0
"""
try:
del self._footer
except AttributeError:
pass
return self
@property @property
def image(self) -> _EmbedMediaProxy: def image(self) -> _EmbedMediaProxy:
"""Returns an ``EmbedProxy`` denoting the image contents. """Returns an ``EmbedProxy`` denoting the image contents.
@@ -397,7 +387,7 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_image', {})) # type: ignore return EmbedProxy(getattr(self, '_image', {})) # type: ignore
def set_image(self: E, *, url: MaybeEmpty[Any]) -> E: def set_image(self: E, *, url: MaybeEmpty[str]) -> E:
"""Sets the image for the embed content. """Sets the image for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -439,7 +429,7 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore
def set_thumbnail(self: E, *, url: MaybeEmpty[Any]) -> E: def set_thumbnail(self: E, *, url: MaybeEmpty[str]) -> E:
"""Sets the thumbnail for the embed content. """Sets the thumbnail for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -500,7 +490,7 @@ class Embed:
""" """
return EmbedProxy(getattr(self, '_author', {})) # type: ignore return EmbedProxy(getattr(self, '_author', {})) # type: ignore
def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: def set_author(self: E, *, name: str, url: MaybeEmpty[str] = EmptyEmbed, icon_url: MaybeEmpty[str] = EmptyEmbed) -> E:
"""Sets the author for the embed content. """Sets the author for the embed content.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -545,7 +535,7 @@ class Embed:
@property @property
def fields(self) -> List[_EmbedFieldProxy]: def fields(self) -> List[_EmbedFieldProxy]:
"""List[Union[``EmbedProxy``, :attr:`Empty`]]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents. """Union[List[:class:`EmbedProxy`], :attr:`Empty`]: Returns a :class:`list` of ``EmbedProxy`` denoting the field contents.
See :meth:`add_field` for possible values you can access. See :meth:`add_field` for possible values you can access.
@@ -553,7 +543,7 @@ class Embed:
""" """
return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore
def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E: def add_field(self: E, *, name: str, value: str, inline: bool = True) -> E:
"""Adds a field to the embed object. """Adds a field to the embed object.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -582,7 +572,7 @@ class Embed:
return self return self
def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: def insert_field_at(self: E, index: int, *, name: str, value: str, inline: bool = True) -> E:
"""Inserts a field before a specified index to the embed. """Inserts a field before a specified index to the embed.
This function returns the class instance to allow for fluent-style This function returns the class instance to allow for fluent-style
@@ -643,7 +633,7 @@ class Embed:
except (AttributeError, IndexError): except (AttributeError, IndexError):
pass pass
def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: def set_field_at(self: E, index: int, *, name: str, value: str, inline: bool = True) -> E:
"""Modifies a field to the embed object. """Modifies a field to the embed object.
The index must point to a valid pre-existing field. The index must point to a valid pre-existing field.

View File

@@ -22,28 +22,16 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from .asset import Asset
from typing import Any, Iterator, List, Optional, TYPE_CHECKING, Tuple from . import utils
from .partial_emoji import _EmojiTag
from .asset import Asset, AssetMixin
from .utils import SnowflakeList, snowflake_time, MISSING
from .partial_emoji import _EmojiTag, PartialEmoji
from .user import User from .user import User
__all__ = ( __all__ = (
'Emoji', 'Emoji',
) )
if TYPE_CHECKING: class Emoji(_EmojiTag):
from .types.emoji import Emoji as EmojiPayload
from .guild import Guild
from .state import ConnectionState
from .abc import Snowflake
from .role import Role
from datetime import datetime
class Emoji(_EmojiTag, AssetMixin):
"""Represents a custom emoji. """Represents a custom emoji.
Depending on the way this object was created, some of the attributes can Depending on the way this object was created, some of the attributes can
@@ -92,76 +80,71 @@ class Emoji(_EmojiTag, AssetMixin):
The user that created the emoji. This can only be retrieved using :meth:`Guild.fetch_emoji` and The user that created the emoji. This can only be retrieved using :meth:`Guild.fetch_emoji` and
having the :attr:`~Permissions.manage_emojis` permission. having the :attr:`~Permissions.manage_emojis` permission.
""" """
__slots__ = ('require_colons', 'animated', 'managed', 'id', 'name', '_roles', 'guild_id',
'_state', 'user', 'available')
__slots__: Tuple[str, ...] = ( def __init__(self, *, guild, state, data):
'require_colons', self.guild_id = guild.id
'animated', self._state = state
'managed',
'id',
'name',
'_roles',
'guild_id',
'_state',
'user',
'available',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload):
self.guild_id: int = guild.id
self._state: ConnectionState = state
self._from_data(data) self._from_data(data)
def _from_data(self, emoji: EmojiPayload): def _from_data(self, emoji):
self.require_colons: bool = emoji.get('require_colons', False) self.require_colons = emoji['require_colons']
self.managed: bool = emoji.get('managed', False) self.managed = emoji['managed']
self.id: int = int(emoji['id']) # type: ignore self.id = int(emoji['id'])
self.name: str = emoji['name'] # type: ignore self.name = emoji['name']
self.animated: bool = emoji.get('animated', False) self.animated = emoji.get('animated', False)
self.available: bool = emoji.get('available', True) self.available = emoji.get('available', True)
self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) self._roles = utils.SnowflakeList(map(int, emoji.get('roles', [])))
user = emoji.get('user') user = emoji.get('user')
self.user: Optional[User] = User(state=self._state, data=user) if user else None self.user = User(state=self._state, data=user) if user else None
def _to_partial(self) -> PartialEmoji: def _iterator(self):
return PartialEmoji(name=self.name, animated=self.animated, id=self.id)
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for attr in self.__slots__: for attr in self.__slots__:
if attr[0] != '_': if attr[0] != '_':
value = getattr(self, attr, None) value = getattr(self, attr, None)
if value is not None: if value is not None:
yield (attr, value) yield (attr, value)
def __str__(self) -> str: def __iter__(self):
return self._iterator()
def __str__(self):
if self.animated: if self.animated:
return f'<a:{self.name}:{self.id}>' return '<a:{0.name}:{0.id}>'.format(self)
return f'<:{self.name}:{self.id}>' return "<:{0.name}:{0.id}>".format(self)
def __repr__(self) -> str: def __int__(self):
return f'<Emoji id={self.id} name={self.name!r} animated={self.animated} managed={self.managed}>' return self.id
def __eq__(self, other: Any) -> bool: def __repr__(self):
return '<Emoji id={0.id} name={0.name!r} animated={0.animated} managed={0.managed}>'.format(self)
def __eq__(self, other):
return isinstance(other, _EmojiTag) and self.id == other.id return isinstance(other, _EmojiTag) and self.id == other.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return self.id >> 22 return self.id >> 22
@property @property
def created_at(self) -> datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the emoji's creation time in UTC.""" """:class:`datetime.datetime`: Returns the emoji's creation time in UTC."""
return snowflake_time(self.id) return utils.snowflake_time(self.id)
@property @property
def url(self) -> str: def url(self):
""":class:`str`: Returns the URL of the emoji.""" """:class:`Asset`: Returns the asset of the emoji.
fmt = 'gif' if self.animated else 'png'
return f'{Asset.BASE}/emojis/{self.id}.{fmt}' This is equivalent to calling :meth:`url_as` with
the default parameters (i.e. png/gif detection).
"""
return self.url_as(format=None)
@property @property
def roles(self) -> List[Role]: def roles(self):
"""List[:class:`Role`]: A :class:`list` of roles that is allowed to use this emoji. """List[:class:`Role`]: A :class:`list` of roles that is allowed to use this emoji.
If roles is empty, the emoji is unrestricted. If roles is empty, the emoji is unrestricted.
@@ -173,11 +156,44 @@ class Emoji(_EmojiTag, AssetMixin):
return [role for role in guild.roles if self._roles.has(role.id)] return [role for role in guild.roles if self._roles.has(role.id)]
@property @property
def guild(self) -> Guild: def guild(self):
""":class:`Guild`: The guild this emoji belongs to.""" """:class:`Guild`: The guild this emoji belongs to."""
return self._state._get_guild(self.guild_id) return self._state._get_guild(self.guild_id)
def is_usable(self) -> bool:
def url_as(self, *, format=None, static_format="png"):
"""Returns an :class:`Asset` for the emoji's url.
The format must be one of 'webp', 'jpeg', 'jpg', 'png' or 'gif'.
'gif' is only valid for animated emojis.
.. versionadded:: 1.6
Parameters
-----------
format: Optional[:class:`str`]
The format to attempt to convert the emojis to.
If the format is ``None``, then it is automatically
detected as either 'gif' or static_format, depending on whether the
emoji is animated or not.
static_format: Optional[:class:`str`]
Format to attempt to convert only non-animated emoji's to.
Defaults to 'png'
Raises
-------
InvalidArgument
Bad image format passed to ``format`` or ``static_format``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_emoji(self._state, self, format=format, static_format=static_format)
def is_usable(self):
""":class:`bool`: Whether the bot can use this emoji. """:class:`bool`: Whether the bot can use this emoji.
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -189,7 +205,7 @@ class Emoji(_EmojiTag, AssetMixin):
emoji_roles, my_roles = self._roles, self.guild.me._roles emoji_roles, my_roles = self._roles, self.guild.me._roles
return any(my_roles.has(role_id) for role_id in emoji_roles) return any(my_roles.has(role_id) for role_id in emoji_roles)
async def delete(self, *, reason: Optional[str] = None) -> None: async def delete(self, *, reason=None):
"""|coro| """|coro|
Deletes the custom emoji. Deletes the custom emoji.
@@ -212,7 +228,7 @@ class Emoji(_EmojiTag, AssetMixin):
await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason)
async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji: async def edit(self, *, name=None, roles=None, reason=None):
r"""|coro| r"""|coro|
Edits the custom emoji. Edits the custom emoji.
@@ -220,15 +236,12 @@ class Emoji(_EmojiTag, AssetMixin):
You must have :attr:`~Permissions.manage_emojis` permission to You must have :attr:`~Permissions.manage_emojis` permission to
do this. do this.
.. versionchanged:: 2.0
The newly updated emoji is returned.
Parameters Parameters
----------- -----------
name: :class:`str` name: :class:`str`
The new emoji name. The new emoji name.
roles: Optional[List[:class:`~discord.abc.Snowflake`]] roles: Optional[list[:class:`Role`]]
A list of roles that can use this emoji. An empty list can be passed to make it available to everyone. A :class:`list` of :class:`Role`\s that can use this emoji. Leave empty to make it available to everyone.
reason: Optional[:class:`str`] reason: Optional[:class:`str`]
The reason for editing this emoji. Shows up on the audit log. The reason for editing this emoji. Shows up on the audit log.
@@ -238,18 +251,9 @@ class Emoji(_EmojiTag, AssetMixin):
You are not allowed to edit emojis. You are not allowed to edit emojis.
HTTPException HTTPException
An error occurred editing the emoji. An error occurred editing the emoji.
Returns
--------
:class:`Emoji`
The newly updated emoji.
""" """
payload = {} name = name or self.name
if name is not MISSING: if roles:
payload['name'] = name roles = [role.id for role in roles]
if roles is not MISSING: await self._state.http.edit_custom_emoji(self.guild.id, self.id, name=name, roles=roles, reason=reason)
payload['roles'] = [role.id for role in roles]
data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason)
return Emoji(guild=self.guild, data=data, state=self._state)

View File

@@ -24,7 +24,7 @@ DEALINGS IN THE SOFTWARE.
import types import types
from collections import namedtuple from collections import namedtuple
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar from typing import Any, TYPE_CHECKING, Type, TypeVar
__all__ = ( __all__ = (
'Enum', 'Enum',
@@ -46,46 +46,24 @@ __all__ = (
'ExpireBehaviour', 'ExpireBehaviour',
'ExpireBehavior', 'ExpireBehavior',
'StickerType', 'StickerType',
'StickerFormatType',
'InviteTarget',
'VideoQualityMode',
'ComponentType',
'ButtonStyle',
'StagePrivacyLevel',
'InteractionType',
'InteractionResponseType',
'NSFWLevel',
) )
def _create_value_cls(name):
def _create_value_cls(name, comparable):
cls = namedtuple('_EnumValue_' + name, 'name value') cls = namedtuple('_EnumValue_' + name, 'name value')
cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>'
cls.__str__ = lambda self: f'{name}.{self.name}' cls.__str__ = lambda self: f'{name}.{self.name}'
if comparable:
cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value
cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value
cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value
cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value
return cls return cls
def _is_descriptor(obj): def _is_descriptor(obj):
return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__')
class EnumMeta(type): class EnumMeta(type):
if TYPE_CHECKING: def __new__(cls, name, bases, attrs):
__name__: ClassVar[str]
_enum_member_names_: ClassVar[List[str]]
_enum_member_map_: ClassVar[Dict[str, Any]]
_enum_value_map_: ClassVar[Dict[Any, Any]]
def __new__(cls, name, bases, attrs, *, comparable: bool = False):
value_mapping = {} value_mapping = {}
member_mapping = {} member_mapping = {}
member_names = [] member_names = []
value_cls = _create_value_cls(name, comparable) value_cls = _create_value_cls(name)
for key, value in list(attrs.items()): for key, value in list(attrs.items()):
is_descriptor = _is_descriptor(value) is_descriptor = _is_descriptor(value)
if key[0] == '_' and not is_descriptor: if key[0] == '_' and not is_descriptor:
@@ -115,7 +93,7 @@ class EnumMeta(type):
attrs['_enum_member_names_'] = member_names attrs['_enum_member_names_'] = member_names
attrs['_enum_value_cls_'] = value_cls attrs['_enum_value_cls_'] = value_cls
actual_cls = super().__new__(cls, name, bases, attrs) actual_cls = super().__new__(cls, name, bases, attrs)
value_cls._actual_enum_cls_ = actual_cls # type: ignore value_cls._actual_enum_cls_ = actual_cls
return actual_cls return actual_cls
def __iter__(cls): def __iter__(cls):
@@ -157,11 +135,9 @@ class EnumMeta(type):
except AttributeError: except AttributeError:
return False return False
if TYPE_CHECKING: if TYPE_CHECKING:
from enum import Enum from enum import Enum
else: else:
class Enum(metaclass=EnumMeta): class Enum(metaclass=EnumMeta):
@classmethod @classmethod
def try_value(cls, value): def try_value(cls, value):
@@ -170,84 +146,72 @@ else:
except (KeyError, TypeError): except (KeyError, TypeError):
return value return value
class ChannelType(Enum): class ChannelType(Enum):
text = 0 text = 0
private = 1 private = 1
voice = 2 voice = 2
group = 3 group = 3
category = 4 category = 4
news = 5 news = 5
store = 6 store = 6
news_thread = 10
public_thread = 11
private_thread = 12
stage_voice = 13 stage_voice = 13
def __str__(self): def __str__(self):
return self.name return self.name
class MessageType(Enum): class MessageType(Enum):
default = 0 default = 0
recipient_add = 1 recipient_add = 1
recipient_remove = 2 recipient_remove = 2
call = 3 call = 3
channel_name_change = 4 channel_name_change = 4
channel_icon_change = 5 channel_icon_change = 5
pins_add = 6 pins_add = 6
new_member = 7 new_member = 7
premium_guild_subscription = 8 premium_guild_subscription = 8
premium_guild_tier_1 = 9 premium_guild_tier_1 = 9
premium_guild_tier_2 = 10 premium_guild_tier_2 = 10
premium_guild_tier_3 = 11 premium_guild_tier_3 = 11
channel_follow_add = 12 channel_follow_add = 12
guild_stream = 13 guild_stream = 13
guild_discovery_disqualified = 14 guild_discovery_disqualified = 14
guild_discovery_requalified = 15 guild_discovery_requalified = 15
guild_discovery_grace_period_initial_warning = 16 guild_discovery_grace_period_initial_warning = 16
guild_discovery_grace_period_final_warning = 17 guild_discovery_grace_period_final_warning = 17
thread_created = 18
reply = 19
application_command = 20
thread_starter_message = 21
guild_invite_reminder = 22
class VoiceRegion(Enum): class VoiceRegion(Enum):
us_west = 'us-west' us_west = 'us-west'
us_east = 'us-east' us_east = 'us-east'
us_south = 'us-south' us_south = 'us-south'
us_central = 'us-central' us_central = 'us-central'
eu_west = 'eu-west' eu_west = 'eu-west'
eu_central = 'eu-central' eu_central = 'eu-central'
singapore = 'singapore' singapore = 'singapore'
london = 'london' london = 'london'
sydney = 'sydney' sydney = 'sydney'
amsterdam = 'amsterdam' amsterdam = 'amsterdam'
frankfurt = 'frankfurt' frankfurt = 'frankfurt'
brazil = 'brazil' brazil = 'brazil'
hongkong = 'hongkong' hongkong = 'hongkong'
russia = 'russia' russia = 'russia'
japan = 'japan' japan = 'japan'
southafrica = 'southafrica' southafrica = 'southafrica'
south_korea = 'south-korea' south_korea = 'south-korea'
india = 'india' india = 'india'
europe = 'europe' europe = 'europe'
dubai = 'dubai' dubai = 'dubai'
vip_us_east = 'vip-us-east' vip_us_east = 'vip-us-east'
vip_us_west = 'vip-us-west' vip_us_west = 'vip-us-west'
vip_amsterdam = 'vip-amsterdam' vip_amsterdam = 'vip-amsterdam'
def __str__(self): def __str__(self):
return self.value return self.value
class SpeakingState(Enum): class SpeakingState(Enum):
none = 0 none = 0
voice = 1 voice = 1
soundshare = 2 soundshare = 2
priority = 4 priority = 4
def __str__(self): def __str__(self):
return self.name return self.name
@@ -255,27 +219,27 @@ class SpeakingState(Enum):
def __int__(self): def __int__(self):
return self.value return self.value
class VerificationLevel(Enum):
class VerificationLevel(Enum, comparable=True): none = 0
none = 0 low = 1
low = 1 medium = 2
medium = 2 high = 3
high = 3 table_flip = 3
highest = 4 extreme = 4
double_table_flip = 4
very_high = 4
def __str__(self): def __str__(self):
return self.name return self.name
class ContentFilter(Enum):
class ContentFilter(Enum, comparable=True): disabled = 0
disabled = 0 no_role = 1
no_role = 1
all_members = 2 all_members = 2
def __str__(self): def __str__(self):
return self.name return self.name
class Status(Enum): class Status(Enum):
online = 'online' online = 'online'
offline = 'offline' offline = 'offline'
@@ -287,32 +251,27 @@ class Status(Enum):
def __str__(self): def __str__(self):
return self.value return self.value
class DefaultAvatar(Enum): class DefaultAvatar(Enum):
blurple = 0 blurple = 0
grey = 1 grey = 1
gray = 1 gray = 1
green = 2 green = 2
orange = 3 orange = 3
red = 4 red = 4
def __str__(self): def __str__(self):
return self.name return self.name
class NotificationLevel(Enum):
class NotificationLevel(Enum, comparable=True): all_messages = 0
all_messages = 0
only_mentions = 1 only_mentions = 1
class AuditLogActionCategory(Enum): class AuditLogActionCategory(Enum):
create = 1 create = 1
delete = 2 delete = 2
update = 3 update = 3
class AuditLogAction(Enum): class AuditLogAction(Enum):
# fmt: off
guild_update = 1 guild_update = 1
channel_create = 10 channel_create = 10
channel_update = 11 channel_update = 11
@@ -348,71 +307,50 @@ class AuditLogAction(Enum):
integration_create = 80 integration_create = 80
integration_update = 81 integration_update = 81
integration_delete = 82 integration_delete = 82
stage_instance_create = 83
stage_instance_update = 84
stage_instance_delete = 85
sticker_create = 90
sticker_update = 91
sticker_delete = 92
thread_create = 110
thread_update = 111
thread_delete = 112
# fmt: on
@property @property
def category(self) -> Optional[AuditLogActionCategory]: def category(self):
# fmt: off lookup = {
lookup: Dict[AuditLogAction, Optional[AuditLogActionCategory]] = { AuditLogAction.guild_update: AuditLogActionCategory.update,
AuditLogAction.guild_update: AuditLogActionCategory.update, AuditLogAction.channel_create: AuditLogActionCategory.create,
AuditLogAction.channel_create: AuditLogActionCategory.create, AuditLogAction.channel_update: AuditLogActionCategory.update,
AuditLogAction.channel_update: AuditLogActionCategory.update, AuditLogAction.channel_delete: AuditLogActionCategory.delete,
AuditLogAction.channel_delete: AuditLogActionCategory.delete, AuditLogAction.overwrite_create: AuditLogActionCategory.create,
AuditLogAction.overwrite_create: AuditLogActionCategory.create, AuditLogAction.overwrite_update: AuditLogActionCategory.update,
AuditLogAction.overwrite_update: AuditLogActionCategory.update, AuditLogAction.overwrite_delete: AuditLogActionCategory.delete,
AuditLogAction.overwrite_delete: AuditLogActionCategory.delete, AuditLogAction.kick: None,
AuditLogAction.kick: None, AuditLogAction.member_prune: None,
AuditLogAction.member_prune: None, AuditLogAction.ban: None,
AuditLogAction.ban: None, AuditLogAction.unban: None,
AuditLogAction.unban: None, AuditLogAction.member_update: AuditLogActionCategory.update,
AuditLogAction.member_update: AuditLogActionCategory.update, AuditLogAction.member_role_update: AuditLogActionCategory.update,
AuditLogAction.member_role_update: AuditLogActionCategory.update, AuditLogAction.member_move: None,
AuditLogAction.member_move: None, AuditLogAction.member_disconnect: None,
AuditLogAction.member_disconnect: None, AuditLogAction.bot_add: None,
AuditLogAction.bot_add: None, AuditLogAction.role_create: AuditLogActionCategory.create,
AuditLogAction.role_create: AuditLogActionCategory.create, AuditLogAction.role_update: AuditLogActionCategory.update,
AuditLogAction.role_update: AuditLogActionCategory.update, AuditLogAction.role_delete: AuditLogActionCategory.delete,
AuditLogAction.role_delete: AuditLogActionCategory.delete, AuditLogAction.invite_create: AuditLogActionCategory.create,
AuditLogAction.invite_create: AuditLogActionCategory.create, AuditLogAction.invite_update: AuditLogActionCategory.update,
AuditLogAction.invite_update: AuditLogActionCategory.update, AuditLogAction.invite_delete: AuditLogActionCategory.delete,
AuditLogAction.invite_delete: AuditLogActionCategory.delete, AuditLogAction.webhook_create: AuditLogActionCategory.create,
AuditLogAction.webhook_create: AuditLogActionCategory.create, AuditLogAction.webhook_update: AuditLogActionCategory.update,
AuditLogAction.webhook_update: AuditLogActionCategory.update, AuditLogAction.webhook_delete: AuditLogActionCategory.delete,
AuditLogAction.webhook_delete: AuditLogActionCategory.delete, AuditLogAction.emoji_create: AuditLogActionCategory.create,
AuditLogAction.emoji_create: AuditLogActionCategory.create, AuditLogAction.emoji_update: AuditLogActionCategory.update,
AuditLogAction.emoji_update: AuditLogActionCategory.update, AuditLogAction.emoji_delete: AuditLogActionCategory.delete,
AuditLogAction.emoji_delete: AuditLogActionCategory.delete, AuditLogAction.message_delete: AuditLogActionCategory.delete,
AuditLogAction.message_delete: AuditLogActionCategory.delete, AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete,
AuditLogAction.message_bulk_delete: AuditLogActionCategory.delete, AuditLogAction.message_pin: None,
AuditLogAction.message_pin: None, AuditLogAction.message_unpin: None,
AuditLogAction.message_unpin: None, AuditLogAction.integration_create: AuditLogActionCategory.create,
AuditLogAction.integration_create: AuditLogActionCategory.create, AuditLogAction.integration_update: AuditLogActionCategory.update,
AuditLogAction.integration_update: AuditLogActionCategory.update, AuditLogAction.integration_delete: AuditLogActionCategory.delete,
AuditLogAction.integration_delete: AuditLogActionCategory.delete,
AuditLogAction.stage_instance_create: AuditLogActionCategory.create,
AuditLogAction.stage_instance_update: AuditLogActionCategory.update,
AuditLogAction.stage_instance_delete: AuditLogActionCategory.delete,
AuditLogAction.sticker_create: AuditLogActionCategory.create,
AuditLogAction.sticker_update: AuditLogActionCategory.update,
AuditLogAction.sticker_delete: AuditLogActionCategory.delete,
AuditLogAction.thread_create: AuditLogActionCategory.create,
AuditLogAction.thread_update: AuditLogActionCategory.update,
AuditLogAction.thread_delete: AuditLogActionCategory.delete,
} }
# fmt: on
return lookup[self] return lookup[self]
@property @property
def target_type(self) -> Optional[str]: def target_type(self):
v = self.value v = self.value
if v == -1: if v == -1:
return 'all' return 'all'
@@ -434,15 +372,8 @@ class AuditLogAction(Enum):
return 'channel' return 'channel'
elif v < 80: elif v < 80:
return 'message' return 'message'
elif v < 83:
return 'integration'
elif v < 90: elif v < 90:
return 'stage_instance' return 'integration'
elif v < 93:
return 'sticker'
elif v < 113:
return 'thread'
class UserFlags(Enum): class UserFlags(Enum):
staff = 1 staff = 1
@@ -461,8 +392,6 @@ class UserFlags(Enum):
bug_hunter_level_2 = 16384 bug_hunter_level_2 = 16384
verified_bot = 65536 verified_bot = 65536
verified_bot_developer = 131072 verified_bot_developer = 131072
discord_certified_moderator = 262144
class ActivityType(Enum): class ActivityType(Enum):
unknown = -1 unknown = -1
@@ -476,128 +405,36 @@ class ActivityType(Enum):
def __int__(self): def __int__(self):
return self.value return self.value
class TeamMembershipState(Enum): class TeamMembershipState(Enum):
invited = 1 invited = 1
accepted = 2 accepted = 2
class WebhookType(Enum): class WebhookType(Enum):
incoming = 1 incoming = 1
channel_follower = 2 channel_follower = 2
application = 3
class ExpireBehaviour(Enum): class ExpireBehaviour(Enum):
remove_role = 0 remove_role = 0
kick = 1 kick = 1
ExpireBehavior = ExpireBehaviour ExpireBehavior = ExpireBehaviour
class StickerType(Enum): class StickerType(Enum):
standard = 1
guild = 2
class StickerFormatType(Enum):
png = 1 png = 1
apng = 2 apng = 2
lottie = 3 lottie = 3
@property
def file_extension(self) -> str:
# fmt: off
lookup: Dict[StickerFormatType, str] = {
StickerFormatType.png: 'png',
StickerFormatType.apng: 'png',
StickerFormatType.lottie: 'json',
}
# fmt: on
return lookup[self]
class InviteTarget(Enum):
unknown = 0
stream = 1
embedded_application = 2
class InteractionType(Enum): class InteractionType(Enum):
ping = 1 ping = 1
application_command = 2 application_command = 2
component = 3
class InteractionResponseType(Enum):
pong = 1
# ack = 2 (deprecated)
# channel_message = 3 (deprecated)
channel_message = 4 # (with source)
deferred_channel_message = 5 # (with source)
deferred_message_update = 6 # for components
message_update = 7 # for components
class VideoQualityMode(Enum):
auto = 1
full = 2
def __int__(self):
return self.value
class ComponentType(Enum):
action_row = 1
button = 2
select = 3
def __int__(self):
return self.value
class ButtonStyle(Enum):
primary = 1
secondary = 2
success = 3
danger = 4
link = 5
# Aliases
blurple = 1
grey = 2
gray = 2
green = 3
red = 4
url = 5
def __int__(self):
return self.value
class StagePrivacyLevel(Enum):
public = 1
closed = 2
guild_only = 2
class NSFWLevel(Enum, comparable=True):
default = 0
explicit = 1
safe = 2
age_restricted = 3
T = TypeVar('T') T = TypeVar('T')
def create_unknown_value(cls: Type[T], val: Any) -> T: def create_unknown_value(cls: Type[T], val: Any) -> T:
value_cls = cls._enum_value_cls_ # type: ignore value_cls = cls._enum_value_cls_ # type: ignore
name = f'unknown_{val}' name = f'unknown_{val}'
return value_cls(name=name, value=val) return value_cls(name=name, value=val)
def try_enum(cls: Type[T], val: Any) -> T: def try_enum(cls: Type[T], val: Any) -> T:
"""A function that tries to turn the value into enum ``cls``. """A function that tries to turn the value into enum ``cls``.
@@ -605,6 +442,6 @@ def try_enum(cls: Type[T], val: Any) -> T:
""" """
try: try:
return cls._enum_value_map_[val] # type: ignore return cls._enum_value_map_[val] # type: ignore
except (KeyError, TypeError, AttributeError): except (KeyError, TypeError, AttributeError):
return create_unknown_value(cls, val) return create_unknown_value(cls, val)

View File

@@ -22,21 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING, Any, Tuple, Union
if TYPE_CHECKING:
from aiohttp import ClientResponse, ClientWebSocketResponse
try:
from requests import Response
_ResponseType = Union[ClientResponse, Response]
except ModuleNotFoundError:
_ResponseType = ClientResponse
from .interactions import Interaction
__all__ = ( __all__ = (
'DiscordException', 'DiscordException',
'ClientException', 'ClientException',
@@ -51,52 +36,44 @@ __all__ = (
'LoginFailure', 'LoginFailure',
'ConnectionClosed', 'ConnectionClosed',
'PrivilegedIntentsRequired', 'PrivilegedIntentsRequired',
'InteractionResponded',
) )
class DiscordException(Exception): class DiscordException(Exception):
"""Base exception class for discord.py """Base exception class for discord.py
Ideally speaking, this could be caught to handle any exceptions raised from this library. Ideally speaking, this could be caught to handle any exceptions thrown from this library.
""" """
pass pass
class ClientException(DiscordException): class ClientException(DiscordException):
"""Exception that's raised when an operation in the :class:`Client` fails. """Exception that's thrown when an operation in the :class:`Client` fails.
These are usually for exceptions that happened due to user input. These are usually for exceptions that happened due to user input.
""" """
pass pass
class NoMoreItems(DiscordException): class NoMoreItems(DiscordException):
"""Exception that is raised when an async iteration operation has no more items.""" """Exception that is thrown when an async iteration operation has no more
items."""
pass pass
class GatewayNotFound(DiscordException): class GatewayNotFound(DiscordException):
"""An exception that is raised when the gateway for Discord could not be found""" """An exception that is usually thrown when the gateway hub
for the :class:`Client` websocket is not found."""
def __init__(self): def __init__(self):
message = 'The gateway to connect to discord was not found.' message = 'The gateway to connect to discord was not found.'
super().__init__(message) super().__init__(message)
def flatten_error_dict(d, key=''):
def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]: items = []
items: List[Tuple[str, str]] = []
for k, v in d.items(): for k, v in d.items():
new_key = key + '.' + k if key else k new_key = key + '.' + k if key else k
if isinstance(v, dict): if isinstance(v, dict):
try: try:
_errors: List[Dict[str, Any]] = v['_errors'] _errors = v['_errors']
except KeyError: except KeyError:
items.extend(_flatten_error_dict(v, new_key).items()) items.extend(flatten_error_dict(v, new_key).items())
else: else:
items.append((new_key, ' '.join(x.get('message', '') for x in _errors))) items.append((new_key, ' '.join(x.get('message', '') for x in _errors)))
else: else:
@@ -104,9 +81,8 @@ def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]:
return dict(items) return dict(items)
class HTTPException(DiscordException): class HTTPException(DiscordException):
"""Exception that's raised when an HTTP request operation fails. """Exception that's thrown when an HTTP request operation fails.
Attributes Attributes
------------ ------------
@@ -123,23 +99,21 @@ class HTTPException(DiscordException):
The Discord specific error code for the failure. The Discord specific error code for the failure.
""" """
def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]): def __init__(self, response, message):
self.response: _ResponseType = response self.response = response
self.status: int = response.status # type: ignore self.status = response.status
self.code: int
self.text: str
if isinstance(message, dict): if isinstance(message, dict):
self.code = message.get('code', 0) self.code = message.get('code', 0)
base = message.get('message', '') base = message.get('message', '')
errors = message.get('errors') errors = message.get('errors')
if errors: if errors:
errors = _flatten_error_dict(errors) errors = flatten_error_dict(errors)
helpful = '\n'.join('In %s: %s' % t for t in errors.items()) helpful = '\n'.join('In %s: %s' % t for t in errors.items())
self.text = base + '\n' + helpful self.text = base + '\n' + helpful
else: else:
self.text = base self.text = base
else: else:
self.text = message or '' self.text = message
self.code = 0 self.code = 0
fmt = '{0.status} {0.reason} (error code: {1})' fmt = '{0.status} {0.reason} (error code: {1})'
@@ -148,67 +122,54 @@ class HTTPException(DiscordException):
super().__init__(fmt.format(self.response, self.code, self.text)) super().__init__(fmt.format(self.response, self.code, self.text))
class Forbidden(HTTPException): class Forbidden(HTTPException):
"""Exception that's raised for when status code 403 occurs. """Exception that's thrown for when status code 403 occurs.
Subclass of :exc:`HTTPException` Subclass of :exc:`HTTPException`
""" """
pass pass
class NotFound(HTTPException): class NotFound(HTTPException):
"""Exception that's raised for when status code 404 occurs. """Exception that's thrown for when status code 404 occurs.
Subclass of :exc:`HTTPException` Subclass of :exc:`HTTPException`
""" """
pass pass
class DiscordServerError(HTTPException): class DiscordServerError(HTTPException):
"""Exception that's raised for when a 500 range status code occurs. """Exception that's thrown for when a 500 range status code occurs.
Subclass of :exc:`HTTPException`. Subclass of :exc:`HTTPException`.
.. versionadded:: 1.5 .. versionadded:: 1.5
""" """
pass pass
class InvalidData(ClientException): class InvalidData(ClientException):
"""Exception that's raised when the library encounters unknown """Exception that's raised when the library encounters unknown
or invalid data from Discord. or invalid data from Discord.
""" """
pass pass
class InvalidArgument(ClientException): class InvalidArgument(ClientException):
"""Exception that's raised when an argument to a function """Exception that's thrown when an argument to a function
is invalid some way (e.g. wrong value or wrong type). is invalid some way (e.g. wrong value or wrong type).
This could be considered the analogous of ``ValueError`` and This could be considered the analogous of ``ValueError`` and
``TypeError`` except inherited from :exc:`ClientException` and thus ``TypeError`` except inherited from :exc:`ClientException` and thus
:exc:`DiscordException`. :exc:`DiscordException`.
""" """
pass pass
class LoginFailure(ClientException): class LoginFailure(ClientException):
"""Exception that's raised when the :meth:`Client.login` function """Exception that's thrown when the :meth:`Client.login` function
fails to log you in from improper credentials or some other misc. fails to log you in from improper credentials or some other misc.
failure. failure.
""" """
pass pass
class ConnectionClosed(ClientException): class ConnectionClosed(ClientException):
"""Exception that's raised when the gateway connection is """Exception that's thrown when the gateway connection is
closed for reasons that could not be handled internally. closed for reasons that could not be handled internally.
Attributes Attributes
@@ -220,19 +181,17 @@ class ConnectionClosed(ClientException):
shard_id: Optional[:class:`int`] shard_id: Optional[:class:`int`]
The shard ID that got closed if applicable. The shard ID that got closed if applicable.
""" """
def __init__(self, socket, *, shard_id, code=None):
def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None):
# This exception is just the same exception except # This exception is just the same exception except
# reconfigured to subclass ClientException for users # reconfigured to subclass ClientException for users
self.code: int = code or socket.close_code or -1 self.code = code or socket.close_code
# aiohttp doesn't seem to consistently provide close reason # aiohttp doesn't seem to consistently provide close reason
self.reason: str = '' self.reason = ''
self.shard_id: Optional[int] = shard_id self.shard_id = shard_id
super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}') super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}')
class PrivilegedIntentsRequired(ClientException): class PrivilegedIntentsRequired(ClientException):
"""Exception that's raised when the gateway is requesting privileged intents """Exception that's thrown when the gateway is requesting privileged intents
but they're not ticked in the developer page yet. but they're not ticked in the developer page yet.
Go to https://discord.com/developers/applications/ and enable the intents Go to https://discord.com/developers/applications/ and enable the intents
@@ -247,31 +206,10 @@ class PrivilegedIntentsRequired(ClientException):
The shard ID that got closed if applicable. The shard ID that got closed if applicable.
""" """
def __init__(self, shard_id: Optional[int]): def __init__(self, shard_id):
self.shard_id: Optional[int] = shard_id self.shard_id = shard_id
msg = ( msg = 'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' \
'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' 'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' \
'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' 'and explicitly enable the privileged intents within your application\'s page. If this is not ' \
'and explicitly enable the privileged intents within your application\'s page. If this is not ' 'possible, then consider disabling the privileged intents instead.'
'possible, then consider disabling the privileged intents instead.'
)
super().__init__(msg % shard_id) super().__init__(msg % shard_id)
class InteractionResponded(ClientException):
"""Exception that's raised when sending another interaction response using
:class:`InteractionResponse` when one has already been done before.
An interaction can only respond once.
.. versionadded:: 2.0
Attributes
-----------
interaction: :class:`Interaction`
The interaction that's already been responded to.
"""
def __init__(self, interaction: Interaction):
self.interaction: Interaction = interaction
super().__init__('This interaction has already been responded to before')

View File

@@ -16,4 +16,3 @@ from .help import *
from .converter import * from .converter import *
from .cooldowns import * from .cooldowns import *
from .cog import * from .cog import *
from .flags import *

View File

@@ -22,26 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from typing import Any, Callable, Coroutine, TYPE_CHECKING, TypeVar, Union
if TYPE_CHECKING:
from .context import Context
from .cog import Cog
from .errors import CommandError
T = TypeVar('T')
Coro = Coroutine[Any, Any, T]
MaybeCoro = Union[T, Coro[T]]
CoroFunc = Callable[..., Coro[Any]]
Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]]
# This is merely a tag type to avoid circular import issues. # This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution. # Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand: class _BaseCommand:

View File

@@ -22,36 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import collections import collections
import collections.abc
import inspect
import importlib.util import importlib.util
import inspect
import itertools
import sys import sys
import traceback import traceback
import types import types
from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union
import discord import discord
from .core import GroupMixin
from .view import StringView
from .context import Context
from . import errors from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog from .cog import Cog
from .context import Context
if TYPE_CHECKING: from .core import GroupMixin
import importlib.machinery from .help import DefaultHelpCommand, HelpCommand
from .view import StringView
from discord.message import Message
from ._types import (
Check,
CoroFunc,
)
__all__ = ( __all__ = (
'when_mentioned', 'when_mentioned',
@@ -60,21 +47,14 @@ __all__ = (
'AutoShardedBot', 'AutoShardedBot',
) )
MISSING: Any = discord.utils.MISSING def when_mentioned(bot, msg):
T = TypeVar('T')
CFT = TypeVar('CFT', bound='CoroFunc')
CXT = TypeVar('CXT', bound='Context')
def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]:
"""A callable that implements a command prefix equivalent to being mentioned. """A callable that implements a command prefix equivalent to being mentioned.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
""" """
# bot.user will never be None when this is called return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> ']
return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore
def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: def when_mentioned_or(*prefixes):
"""A callable that implements when mentioned or other prefixes provided. """A callable that implements when mentioned or other prefixes provided.
These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute.
@@ -110,7 +90,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M
return inner return inner
def _is_submodule(parent: str, child: str) -> bool: def _is_submodule(parent, child):
return parent == child or child.startswith(parent + ".") return parent == child or child.startswith(parent + ".")
class _DefaultRepr: class _DefaultRepr:
@@ -120,13 +100,14 @@ class _DefaultRepr:
_default = _DefaultRepr() _default = _DefaultRepr()
class BotBase(GroupMixin): class BotBase(GroupMixin):
def __init__(self, command_prefix, help_command=_default, description=None, **options): def __init__(self, command_prefix, case_insensitive_prefix=False, help_command=_default, description=None, **options):
super().__init__(**options) super().__init__(**options)
self.command_prefix = command_prefix self.command_prefix = command_prefix
self.extra_events: Dict[str, List[CoroFunc]] = {} self.case_insensitive_prefix = case_insensitive_prefix
self.__cogs: Dict[str, Cog] = {} self.extra_events = {}
self.__extensions: Dict[str, types.ModuleType] = {} self.__cogs = {}
self._checks: List[Check] = [] self.__extensions = {}
self._checks = []
self._check_once = [] self._check_once = []
self._before_invoke = None self._before_invoke = None
self._after_invoke = None self._after_invoke = None
@@ -142,22 +123,48 @@ class BotBase(GroupMixin):
if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection):
raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}') raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}')
if options.pop('self_bot', False):
self._skip_check = lambda x, y: x != y
else:
self._skip_check = lambda x, y: x == y
if help_command is _default: if help_command is _default:
self.help_command = DefaultHelpCommand() self.help_command = DefaultHelpCommand()
else: else:
self.help_command = help_command self.help_command = help_command
@property
def owner(self):
""":class:`discord.User`: The owner, retrieved from owner_id. In case of improper caching, this can return None
.. versionadded:: 1.5.0.1"""
if not self.owner_id or self.owner_ids:
raise AttributeError('No owner_id specified or you used owner_ids. If you used owner_ids, please refer to `Bot.owners`')
return self.get_user(self.owner_id)
@property
def owners(self):
"""List[:class:`discord.User`]: The owners, retrieved from owner_ids. In case of improper caching, this list may not contain all owners.
.. versionadded:: 1.5.0.1"""
if not self.owner_ids or self.owner_id:
raise TypeError('No owner_ids specified or you used owner_id. If you used owner_id, please refer to `Bot.owner`')
owners = []
for user in self.owner_ids:
owner = self.get_user(user)
if owner:
owners.append(owner)
return owners
# internal helpers # internal helpers
def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: def dispatch(self, event_name, *args, **kwargs):
# super() will resolve to Client super().dispatch(event_name, *args, **kwargs)
super().dispatch(event_name, *args, **kwargs) # type: ignore
ev = 'on_' + event_name ev = 'on_' + event_name
for event in self.extra_events.get(ev, []): for event in self.extra_events.get(ev, []):
self._schedule_event(event, ev, *args, **kwargs) # type: ignore self._schedule_event(event, ev, *args, **kwargs)
@discord.utils.copy_doc(discord.Client.close) async def close(self):
async def close(self) -> None:
for extension in tuple(self.__extensions): for extension in tuple(self.__extensions):
try: try:
self.unload_extension(extension) self.unload_extension(extension)
@@ -170,9 +177,9 @@ class BotBase(GroupMixin):
except Exception: except Exception:
pass pass
await super().close() # type: ignore await super().close()
async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: async def on_command_error(self, context, exception):
"""|coro| """|coro|
The default command error handler provided by the bot. The default command error handler provided by the bot.
@@ -185,12 +192,11 @@ class BotBase(GroupMixin):
if self.extra_events.get('on_command_error', None): if self.extra_events.get('on_command_error', None):
return return
command = context.command if hasattr(context.command, 'on_error'):
if command and command.has_error_handler():
return return
cog = context.cog cog = context.cog
if cog and cog.has_error_handler(): if cog and Cog._get_overridden_method(cog.cog_command_error) is not None:
return return
print(f'Ignoring exception in command {context.command}:', file=sys.stderr) print(f'Ignoring exception in command {context.command}:', file=sys.stderr)
@@ -198,7 +204,7 @@ class BotBase(GroupMixin):
# global check registration # global check registration
def check(self, func: T) -> T: def check(self, func):
r"""A decorator that adds a global check to the bot. r"""A decorator that adds a global check to the bot.
A global check is similar to a :func:`.check` that is applied A global check is similar to a :func:`.check` that is applied
@@ -223,11 +229,10 @@ class BotBase(GroupMixin):
return ctx.command.qualified_name in allowed_commands return ctx.command.qualified_name in allowed_commands
""" """
# T was used instead of Check to ensure the type matches on return self.add_check(func)
self.add_check(func) # type: ignore
return func return func
def add_check(self, func: Check, *, call_once: bool = False) -> None: def add_check(self, func, *, call_once=False):
"""Adds a global check to the bot. """Adds a global check to the bot.
This is the non-decorator interface to :meth:`.check` This is the non-decorator interface to :meth:`.check`
@@ -239,7 +244,7 @@ class BotBase(GroupMixin):
The function that was used as a global check. The function that was used as a global check.
call_once: :class:`bool` call_once: :class:`bool`
If the function should only be called once per If the function should only be called once per
:meth:`.invoke` call. :meth:`.Command.invoke` call.
""" """
if call_once: if call_once:
@@ -247,7 +252,7 @@ class BotBase(GroupMixin):
else: else:
self._checks.append(func) self._checks.append(func)
def remove_check(self, func: Check, *, call_once: bool = False) -> None: def remove_check(self, func, *, call_once=False):
"""Removes a global check from the bot. """Removes a global check from the bot.
This function is idempotent and will not raise an exception This function is idempotent and will not raise an exception
@@ -268,11 +273,11 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
def check_once(self, func: CFT) -> CFT: def check_once(self, func):
r"""A decorator that adds a "call once" global check to the bot. r"""A decorator that adds a "call once" global check to the bot.
Unlike regular global checks, this one is called only once Unlike regular global checks, this one is called only once
per :meth:`.invoke` call. per :meth:`.Command.invoke` call.
Regular global checks are called whenever a command is called Regular global checks are called whenever a command is called
or :meth:`.Command.can_run` is called. This type of check or :meth:`.Command.can_run` is called. This type of check
@@ -306,16 +311,15 @@ class BotBase(GroupMixin):
self.add_check(func, call_once=True) self.add_check(func, call_once=True)
return func return func
async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: async def can_run(self, ctx, *, call_once=False):
data = self._check_once if call_once else self._checks data = self._check_once if call_once else self._checks
if len(data) == 0: if len(data) == 0:
return True return True
# type-checker doesn't distinguish between functions and methods return await discord.utils.async_all(f(ctx) for f in data)
return await discord.utils.async_all(f(ctx) for f in data) # type: ignore
async def is_owner(self, user: discord.User) -> bool: async def is_owner(self, user):
"""|coro| """|coro|
Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of Checks if a :class:`~discord.User` or :class:`~discord.Member` is the owner of
@@ -344,8 +348,7 @@ class BotBase(GroupMixin):
elif self.owner_ids: elif self.owner_ids:
return user.id in self.owner_ids return user.id in self.owner_ids
else: else:
app = await self.application_info()
app = await self.application_info() # type: ignore
if app.team: if app.team:
self.owner_ids = ids = {m.id for m in app.team.members} self.owner_ids = ids = {m.id for m in app.team.members}
return user.id in ids return user.id in ids
@@ -353,7 +356,7 @@ class BotBase(GroupMixin):
self.owner_id = owner_id = app.owner.id self.owner_id = owner_id = app.owner.id
return user.id == owner_id return user.id == owner_id
def before_invoke(self, coro: CFT) -> CFT: def before_invoke(self, coro):
"""A decorator that registers a coroutine as a pre-invoke hook. """A decorator that registers a coroutine as a pre-invoke hook.
A pre-invoke hook is called directly before the command is A pre-invoke hook is called directly before the command is
@@ -385,7 +388,7 @@ class BotBase(GroupMixin):
self._before_invoke = coro self._before_invoke = coro
return coro return coro
def after_invoke(self, coro: CFT) -> CFT: def after_invoke(self, coro):
r"""A decorator that registers a coroutine as a post-invoke hook. r"""A decorator that registers a coroutine as a post-invoke hook.
A post-invoke hook is called directly after the command is A post-invoke hook is called directly after the command is
@@ -420,14 +423,14 @@ class BotBase(GroupMixin):
# listener registration # listener registration
def add_listener(self, func: CoroFunc, name: str = MISSING) -> None: def add_listener(self, func, name=None):
"""The non decorator alternative to :meth:`.listen`. """The non decorator alternative to :meth:`.listen`.
Parameters Parameters
----------- -----------
func: :ref:`coroutine <coroutine>` func: :ref:`coroutine <coroutine>`
The function to call. The function to call.
name: :class:`str` name: Optional[:class:`str`]
The name of the event to listen for. Defaults to ``func.__name__``. The name of the event to listen for. Defaults to ``func.__name__``.
Example Example
@@ -442,7 +445,7 @@ class BotBase(GroupMixin):
bot.add_listener(my_message, 'on_message') bot.add_listener(my_message, 'on_message')
""" """
name = func.__name__ if name is MISSING else name name = func.__name__ if name is None else name
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be coroutines') raise TypeError('Listeners must be coroutines')
@@ -452,7 +455,7 @@ class BotBase(GroupMixin):
else: else:
self.extra_events[name] = [func] self.extra_events[name] = [func]
def remove_listener(self, func: CoroFunc, name: str = MISSING) -> None: def remove_listener(self, func, name=None):
"""Removes a listener from the pool of listeners. """Removes a listener from the pool of listeners.
Parameters Parameters
@@ -464,7 +467,7 @@ class BotBase(GroupMixin):
``func.__name__``. ``func.__name__``.
""" """
name = func.__name__ if name is MISSING else name name = func.__name__ if name is None else name
if name in self.extra_events: if name in self.extra_events:
try: try:
@@ -472,7 +475,7 @@ class BotBase(GroupMixin):
except ValueError: except ValueError:
pass pass
def listen(self, name: str = MISSING) -> Callable[[CFT], CFT]: def listen(self, name=None):
"""A decorator that registers another function as an external """A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready` events from different places e.g. such as :func:`.on_ready`
@@ -502,7 +505,7 @@ class BotBase(GroupMixin):
The function being listened to is not a coroutine. The function being listened to is not a coroutine.
""" """
def decorator(func: CFT) -> CFT: def decorator(func):
self.add_listener(func, name) self.add_listener(func, name)
return func return func
@@ -510,25 +513,15 @@ class BotBase(GroupMixin):
# cogs # cogs
def add_cog(self, cog: Cog, *, override: bool = False) -> None: def add_cog(self, cog):
"""Adds a "cog" to the bot. """Adds a "cog" to the bot.
A cog is a class that has its own event listeners and commands. A cog is a class that has its own event listeners and commands.
.. versionchanged:: 2.0
:exc:`.ClientException` is raised when a cog with the same name
is already loaded.
Parameters Parameters
----------- -----------
cog: :class:`.Cog` cog: :class:`.Cog`
The cog to register to the bot. The cog to register to the bot.
override: :class:`bool`
If a previously loaded cog with the same name should be ejected
instead of raising an error.
.. versionadded:: 2.0
Raises Raises
------- -------
@@ -536,25 +529,18 @@ class BotBase(GroupMixin):
The cog does not inherit from :class:`.Cog`. The cog does not inherit from :class:`.Cog`.
CommandError CommandError
An error happened during loading. An error happened during loading.
.ClientException
A cog with the same name is already loaded.
""" """
if not isinstance(cog, Cog): if not isinstance(cog, Cog):
raise TypeError('cogs must derive from Cog') raise TypeError('cogs must derive from Cog')
cog_name = cog.__cog_name__
existing = self.__cogs.get(cog_name)
if existing is not None:
if not override:
raise discord.ClientException(f'Cog named {cog_name!r} already loaded')
self.remove_cog(cog_name)
cog = cog._inject(self) cog = cog._inject(self)
self.__cogs[cog_name] = cog self.__cogs[cog.__cog_name__] = cog
if cog.aliases:
for alias in cog.aliases:
self.__cogs[alias] = cog
def get_cog(self, name: str) -> Optional[Cog]: def get_cog(self, name):
"""Gets the cog instance requested. """Gets the cog instance requested.
If the cog is not found, ``None`` is returned instead. If the cog is not found, ``None`` is returned instead.
@@ -573,8 +559,8 @@ class BotBase(GroupMixin):
""" """
return self.__cogs.get(name) return self.__cogs.get(name)
def remove_cog(self, name: str) -> Optional[Cog]: def remove_cog(self, name):
"""Removes a cog from the bot and returns it. """Removes a cog from the bot.
All registered commands and event listeners that the All registered commands and event listeners that the
cog has registered will be removed as well. cog has registered will be removed as well.
@@ -585,32 +571,29 @@ class BotBase(GroupMixin):
----------- -----------
name: :class:`str` name: :class:`str`
The name of the cog to remove. The name of the cog to remove.
Returns
-------
Optional[:class:`.Cog`]
The cog that was removed. ``None`` if not found.
""" """
cog = self.__cogs.pop(name, None) cog = self.__cogs.pop(name, None)
if cog is None: if cog is None:
return return
if cog.aliases:
for alias in cog.aliases:
self.__cogs.pop(alias)
help_command = self._help_command help_command = self._help_command
if help_command and help_command.cog is cog: if help_command and help_command.cog is cog:
help_command.cog = None help_command.cog = None
cog._eject(self) cog._eject(self)
return cog
@property @property
def cogs(self) -> Mapping[str, Cog]: def cogs(self):
"""Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog.""" """Mapping[:class:`str`, :class:`Cog`]: A read-only mapping of cog name to cog."""
return types.MappingProxyType(self.__cogs) return types.MappingProxyType(self.__cogs)
# extensions # extensions
def _remove_module_references(self, name: str) -> None: def _remove_module_references(self, name):
# find all references to the module # find all references to the module
# remove the cogs registered from the module # remove the cogs registered from the module
for cogname, cog in self.__cogs.copy().items(): for cogname, cog in self.__cogs.copy().items():
@@ -634,7 +617,7 @@ class BotBase(GroupMixin):
for index in reversed(remove): for index in reversed(remove):
del event_list[index] del event_list[index]
def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: def _call_module_finalizers(self, lib, key):
try: try:
func = getattr(lib, 'teardown') func = getattr(lib, 'teardown')
except AttributeError: except AttributeError:
@@ -652,12 +635,12 @@ class BotBase(GroupMixin):
if _is_submodule(name, module): if _is_submodule(name, module):
del sys.modules[module] del sys.modules[module]
def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: def _load_from_module_spec(self, spec, key):
# precondition: key not in self.__extensions # precondition: key not in self.__extensions
lib = importlib.util.module_from_spec(spec) lib = importlib.util.module_from_spec(spec)
sys.modules[key] = lib sys.modules[key] = lib
try: try:
spec.loader.exec_module(lib) # type: ignore spec.loader.exec_module(lib)
except Exception as e: except Exception as e:
del sys.modules[key] del sys.modules[key]
raise errors.ExtensionFailed(key, e) from e raise errors.ExtensionFailed(key, e) from e
@@ -678,13 +661,13 @@ class BotBase(GroupMixin):
else: else:
self.__extensions[key] = lib self.__extensions[key] = lib
def _resolve_name(self, name: str, package: Optional[str]) -> str: def _resolve_name(self, name, package):
try: try:
return importlib.util.resolve_name(name, package) return importlib.util.resolve_name(name, package)
except ImportError: except ImportError:
raise errors.ExtensionNotFound(name) raise errors.ExtensionNotFound(name)
def load_extension(self, name: str, *, package: Optional[str] = None) -> None: def load_extension(self, name, *, package=None):
"""Loads an extension. """Loads an extension.
An extension is a python module that contains commands, cogs, or An extension is a python module that contains commands, cogs, or
@@ -731,7 +714,7 @@ class BotBase(GroupMixin):
self._load_from_module_spec(spec, name) self._load_from_module_spec(spec, name)
def unload_extension(self, name: str, *, package: Optional[str] = None) -> None: def unload_extension(self, name, *, package=None):
"""Unloads an extension. """Unloads an extension.
When the extension is unloaded, all commands, listeners, and cogs are When the extension is unloaded, all commands, listeners, and cogs are
@@ -772,7 +755,7 @@ class BotBase(GroupMixin):
self._remove_module_references(lib.__name__) self._remove_module_references(lib.__name__)
self._call_module_finalizers(lib, name) self._call_module_finalizers(lib, name)
def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: def reload_extension(self, name, *, package=None):
"""Atomically reloads an extension. """Atomically reloads an extension.
This replaces the extension with the same extension, only refreshed. This is This replaces the extension with the same extension, only refreshed. This is
@@ -828,7 +811,7 @@ class BotBase(GroupMixin):
# if the load failed, the remnants should have been # if the load failed, the remnants should have been
# cleaned from the load_extension function call # cleaned from the load_extension function call
# so let's load it from our old compiled library. # so let's load it from our old compiled library.
lib.setup(self) # type: ignore lib.setup(self)
self.__extensions[name] = lib self.__extensions[name] = lib
# revert sys.modules back to normal and raise back to caller # revert sys.modules back to normal and raise back to caller
@@ -836,18 +819,18 @@ class BotBase(GroupMixin):
raise raise
@property @property
def extensions(self) -> Mapping[str, types.ModuleType]: def extensions(self):
"""Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension.""" """Mapping[:class:`str`, :class:`py:types.ModuleType`]: A read-only mapping of extension name to extension."""
return types.MappingProxyType(self.__extensions) return types.MappingProxyType(self.__extensions)
# help command stuff # help command stuff
@property @property
def help_command(self) -> Optional[HelpCommand]: def help_command(self):
return self._help_command return self._help_command
@help_command.setter @help_command.setter
def help_command(self, value: Optional[HelpCommand]) -> None: def help_command(self, value):
if value is not None: if value is not None:
if not isinstance(value, HelpCommand): if not isinstance(value, HelpCommand):
raise TypeError('help_command must be a subclass of HelpCommand') raise TypeError('help_command must be a subclass of HelpCommand')
@@ -863,7 +846,7 @@ class BotBase(GroupMixin):
# command processing # command processing
async def get_prefix(self, message: Message) -> Union[List[str], str]: async def get_prefix(self, message):
"""|coro| """|coro|
Retrieves the prefix the bot is listening to Retrieves the prefix the bot is listening to
@@ -894,14 +877,24 @@ class BotBase(GroupMixin):
raise raise
raise TypeError("command_prefix must be plain string, iterable of strings, or callable " raise TypeError("command_prefix must be plain string, iterable of strings, or callable "
f"returning either of these, not {ret.__class__.__name__}") "returning either of these, not {}".format(ret.__class__.__name__))
if not ret: if not ret:
raise ValueError("Iterable command_prefix must contain at least one prefix") raise ValueError("Iterable command_prefix must contain at least one prefix")
# if self.case_insensitive_prefix:
# if isinstance(ret, list):
# temp = []
# for pre in ret:
# if pre in (self.user.mention + ' ', '<@!%s> ' % self.user.id):
# continue
# temp += list(map(''.join, itertools.product(*((c.upper(), c.lower()) for c in pre))))
# ret = temp
# else:
# ret = list(map(''.join, itertools.product(*((c.upper(), c.lower()) for c in ret))))
return ret return ret
async def get_context(self, message: Message, *, cls: Type[CXT] = Context) -> CXT: async def get_context(self, message, *, cls=Context):
r"""|coro| r"""|coro|
Returns the invocation context from the message. Returns the invocation context from the message.
@@ -934,7 +927,7 @@ class BotBase(GroupMixin):
view = StringView(message.content) view = StringView(message.content)
ctx = cls(prefix=None, view=view, bot=self, message=message) ctx = cls(prefix=None, view=view, bot=self, message=message)
if message.author.id == self.user.id: # type: ignore if self._skip_check(message.author.id, self.user.id):
return ctx return ctx
prefix = await self.get_prefix(message) prefix = await self.get_prefix(message)
@@ -955,13 +948,13 @@ class BotBase(GroupMixin):
except TypeError: except TypeError:
if not isinstance(prefix, list): if not isinstance(prefix, list):
raise TypeError("get_prefix must return either a string or a list of string, " raise TypeError("get_prefix must return either a string or a list of string, "
f"not {prefix.__class__.__name__}") "not {}".format(prefix.__class__.__name__))
# It's possible a bad command_prefix got us here. # It's possible a bad command_prefix got us here.
for value in prefix: for value in prefix:
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError("Iterable command_prefix or list returned from get_prefix must " raise TypeError("Iterable command_prefix or list returned from get_prefix must "
f"contain only strings, not {value.__class__.__name__}") "contain only strings, not {}".format(value.__class__.__name__))
# Getting here shouldn't happen # Getting here shouldn't happen
raise raise
@@ -971,12 +964,11 @@ class BotBase(GroupMixin):
invoker = view.get_word() invoker = view.get_word()
ctx.invoked_with = invoker ctx.invoked_with = invoker
# type-checker fails to narrow invoked_prefix type. ctx.prefix = invoked_prefix
ctx.prefix = invoked_prefix # type: ignore
ctx.command = self.all_commands.get(invoker) ctx.command = self.all_commands.get(invoker)
return ctx return ctx
async def invoke(self, ctx: Context) -> None: async def invoke(self, ctx):
"""|coro| """|coro|
Invokes the command given under the invocation context and Invokes the command given under the invocation context and
@@ -1002,7 +994,7 @@ class BotBase(GroupMixin):
exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found')
self.dispatch('command_error', ctx, exc) self.dispatch('command_error', ctx, exc)
async def process_commands(self, message: Message) -> None: async def process_commands(self, message):
"""|coro| """|coro|
This function processes the commands that have been registered This function processes the commands that have been registered
@@ -1074,12 +1066,20 @@ class Bot(BotBase, discord.Client):
matches messages starting with ``!?``. This is especially important matches messages starting with ``!?``. This is especially important
when passing an empty string, it should always be last as no prefix when passing an empty string, it should always be last as no prefix
after it will be matched. after it will be matched.
case_insensitive_prefix: :class:`bool`
Wheter the provided command_prefix should be case insensitive or not
.. versionadded:: 1.6.0.7
case_insensitive: :class:`bool` case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``. This Whether the commands should be case insensitive. Defaults to ``False``. This
attribute does not carry over to groups. You must set it to every group if attribute does not carry over to groups. You must set it to every group if
you require group commands to be case insensitive as well. you require group commands to be case insensitive as well.
description: :class:`str` description: :class:`str`
The content prefixed into the default help message. The content prefixed into the default help message.
self_bot: :class:`bool`
If ``True``, the bot will only listen to commands invoked by itself rather
than ignoring itself. If ``False`` (the default) then the bot will ignore
itself. This cannot be changed once initialised.
help_command: Optional[:class:`.HelpCommand`] help_command: Optional[:class:`.HelpCommand`]
The help command implementation to use. This can be dynamically The help command implementation to use. This can be dynamically
set at runtime. To remove the help command pass ``None``. For more set at runtime. To remove the help command pass ``None``. For more

View File

@@ -21,30 +21,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import inspect import inspect
import discord.utils import copy
from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type
from ._types import _BaseCommand from ._types import _BaseCommand
if TYPE_CHECKING:
from .bot import BotBase
from .context import Context
from .core import Command
__all__ = ( __all__ = (
'CogMeta', 'CogMeta',
'Cog', 'Cog',
) )
CogT = TypeVar('CogT', bound='Cog')
FuncT = TypeVar('FuncT', bound=Callable[..., Any])
MISSING: Any = discord.utils.MISSING
class CogMeta(type): class CogMeta(type):
"""A metaclass for defining a cog. """A metaclass for defining a cog.
@@ -103,17 +89,23 @@ class CogMeta(type):
@commands.command(hidden=False) @commands.command(hidden=False)
async def bar(self, ctx): async def bar(self, ctx):
pass # hidden -> False pass # hidden -> False
"""
__cog_name__: str
__cog_settings__: Dict[str, Any]
__cog_commands__: List[Command]
__cog_listeners__: List[Tuple[str, str]]
def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: aliases: :class:`list`
A list of aliases for the cog name.
.. versionadded:: 1.6.0.7
"""
def __new__(cls, *args, **kwargs):
name, bases, attrs = args name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name) attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
aliases = kwargs.pop('aliases', [])
if not isinstance(aliases, list):
raise TypeError("Cog aliases must be a list, not a {0}".format(type(aliases)))
attrs['aliases'] = aliases
description = kwargs.pop('description', None) description = kwargs.pop('description', None)
if description is None: if description is None:
description = inspect.cleandoc(attrs.get('__doc__', '')) description = inspect.cleandoc(attrs.get('__doc__', ''))
@@ -162,14 +154,14 @@ class CogMeta(type):
new_cls.__cog_listeners__ = listeners_as_list new_cls.__cog_listeners__ = listeners_as_list
return new_cls return new_cls
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args, **kwargs):
super().__init__(*args) super().__init__(*args)
@classmethod @classmethod
def qualified_name(cls) -> str: def qualified_name(cls):
return cls.__cog_name__ return cls.__cog_name__
def _cog_special_method(func: FuncT) -> FuncT: def _cog_special_method(func):
func.__cog_special_method__ = None func.__cog_special_method__ = None
return func return func
@@ -183,12 +175,8 @@ class Cog(metaclass=CogMeta):
When inheriting from this class, the options shown in :class:`CogMeta` When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here. are equally valid here.
""" """
__cog_name__: ClassVar[str]
__cog_settings__: ClassVar[Dict[str, Any]]
__cog_commands__: ClassVar[List[Command]]
__cog_listeners__: ClassVar[List[Tuple[str, str]]]
def __new__(cls: Type[CogT], *args: Any, **kwargs: Any) -> CogT: def __new__(cls, *args, **kwargs):
# For issue 426, we need to store a copy of the command objects # For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them. # since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process. # To do this, we need to interfere with the Cog creation process.
@@ -196,8 +184,7 @@ class Cog(metaclass=CogMeta):
cmd_attrs = cls.__cog_settings__ cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it. # Either update the command with the cog provided defaults or copy it.
# r.e type ignore, type-checker complains about overriding a ClassVar self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore
lookup = { lookup = {
cmd.qualified_name: cmd cmd.qualified_name: cmd
@@ -210,15 +197,15 @@ class Cog(metaclass=CogMeta):
parent = command.parent parent = command.parent
if parent is not None: if parent is not None:
# Get the latest parent reference # Get the latest parent reference
parent = lookup[parent.qualified_name] # type: ignore parent = lookup[parent.qualified_name]
# Update our parent's reference to our self # Update our parent's reference to our self
parent.remove_command(command.name) # type: ignore parent.remove_command(command.name)
parent.add_command(command) # type: ignore parent.add_command(command)
return self return self
def get_commands(self) -> List[Command]: def get_commands(self):
r""" r"""
Returns Returns
-------- --------
@@ -233,20 +220,20 @@ class Cog(metaclass=CogMeta):
return [c for c in self.__cog_commands__ if c.parent is None] return [c for c in self.__cog_commands__ if c.parent is None]
@property @property
def qualified_name(self) -> str: def qualified_name(self):
""":class:`str`: Returns the cog's specified name, not the class name.""" """:class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__ return self.__cog_name__
@property @property
def description(self) -> str: def description(self):
""":class:`str`: Returns the cog's description, typically the cleaned docstring.""" """:class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__ return self.__cog_description__
@description.setter @description.setter
def description(self, description: str) -> None: def description(self, description):
self.__cog_description__ = description self.__cog_description__ = description
def walk_commands(self) -> Generator[Command, None, None]: def walk_commands(self):
"""An iterator that recursively walks through this cog's commands and subcommands. """An iterator that recursively walks through this cog's commands and subcommands.
Yields Yields
@@ -261,7 +248,7 @@ class Cog(metaclass=CogMeta):
if isinstance(command, GroupMixin): if isinstance(command, GroupMixin):
yield from command.walk_commands() yield from command.walk_commands()
def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]: def get_listeners(self):
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog. """Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns Returns
@@ -272,12 +259,12 @@ class Cog(metaclass=CogMeta):
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod @classmethod
def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: def _get_overridden_method(cls, method):
"""Return None if the method is not overridden. Otherwise returns the overridden method.""" """Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method) return getattr(method.__func__, '__cog_special_method__', method)
@classmethod @classmethod
def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: def listener(cls, name=None):
"""A decorator that marks a function as a listener. """A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`. This is the cog equivalent of :meth:`.Bot.listen`.
@@ -295,10 +282,10 @@ class Cog(metaclass=CogMeta):
the name. the name.
""" """
if name is not MISSING and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.')
def decorator(func: FuncT) -> FuncT: def decorator(func):
actual = func actual = func
if isinstance(actual, staticmethod): if isinstance(actual, staticmethod):
actual = actual.__func__ actual = actual.__func__
@@ -317,7 +304,7 @@ class Cog(metaclass=CogMeta):
return func return func
return decorator return decorator
def has_error_handler(self) -> bool: def has_error_handler(self):
""":class:`bool`: Checks whether the cog has an error handler. """:class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7 .. versionadded:: 1.7
@@ -325,7 +312,7 @@ class Cog(metaclass=CogMeta):
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method @_cog_special_method
def cog_unload(self) -> None: def cog_unload(self):
"""A special method that is called when the cog gets removed. """A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular This function **cannot** be a coroutine. It must be a regular
@@ -336,7 +323,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
def bot_check_once(self, ctx: Context) -> bool: def bot_check_once(self, ctx):
"""A special method that registers as a :meth:`.Bot.check_once` """A special method that registers as a :meth:`.Bot.check_once`
check. check.
@@ -346,7 +333,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def bot_check(self, ctx: Context) -> bool: def bot_check(self, ctx):
"""A special method that registers as a :meth:`.Bot.check` """A special method that registers as a :meth:`.Bot.check`
check. check.
@@ -356,8 +343,8 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
def cog_check(self, ctx: Context) -> bool: def cog_check(self, ctx):
"""A special method that registers as a :func:`~discord.ext.commands.check` """A special method that registers as a :func:`commands.check`
for every command and subcommand in this cog. for every command and subcommand in this cog.
This function **can** be a coroutine and must take a sole parameter, This function **can** be a coroutine and must take a sole parameter,
@@ -366,7 +353,7 @@ class Cog(metaclass=CogMeta):
return True return True
@_cog_special_method @_cog_special_method
async def cog_command_error(self, ctx: Context, error: Exception) -> None: async def cog_command_error(self, ctx, error):
"""A special method that is called whenever an error """A special method that is called whenever an error
is dispatched inside this cog. is dispatched inside this cog.
@@ -385,7 +372,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_before_invoke(self, ctx: Context) -> None: async def cog_before_invoke(self, ctx):
"""A special method that acts as a cog local pre-invoke hook. """A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`. This is similar to :meth:`.Command.before_invoke`.
@@ -400,7 +387,7 @@ class Cog(metaclass=CogMeta):
pass pass
@_cog_special_method @_cog_special_method
async def cog_after_invoke(self, ctx: Context) -> None: async def cog_after_invoke(self, ctx):
"""A special method that acts as a cog local post-invoke hook. """A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`. This is similar to :meth:`.Command.after_invoke`.
@@ -414,7 +401,7 @@ class Cog(metaclass=CogMeta):
""" """
pass pass
def _inject(self: CogT, bot: BotBase) -> CogT: def _inject(self, bot):
cls = self.__class__ cls = self.__class__
# realistically, the only thing that can cause loading errors # realistically, the only thing that can cause loading errors
@@ -449,7 +436,7 @@ class Cog(metaclass=CogMeta):
return self return self
def _eject(self, bot: BotBase) -> None: def _eject(self, bot):
cls = self.__class__ cls = self.__class__
try: try:

View File

@@ -21,52 +21,16 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import inspect
import re import re
from typing import Any, Dict, Generic, List, Optional, TYPE_CHECKING, TypeVar, Union
import discord.abc import discord.abc
import discord.utils import discord.utils
from discord.message import Message
if TYPE_CHECKING:
from typing_extensions import ParamSpec
from discord.abc import MessageableChannel
from discord.guild import Guild
from discord.member import Member
from discord.state import ConnectionState
from discord.user import ClientUser, User
from discord.voice_client import VoiceProtocol
from .bot import Bot, AutoShardedBot
from .cog import Cog
from .core import Command
from .help import HelpCommand
from .view import StringView
__all__ = ( __all__ = (
'Context', 'Context',
) )
MISSING: Any = discord.utils.MISSING class Context(discord.abc.Messageable):
T = TypeVar('T')
BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]")
CogT = TypeVar('CogT', bound="Cog")
if TYPE_CHECKING:
P = ParamSpec('P')
else:
P = TypeVar('P')
class Context(discord.abc.Messageable, Generic[BotT]):
r"""Represents the context in which a command is being invoked under. r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about This class contains a lot of meta data to help you understand more about
@@ -83,22 +47,17 @@ class Context(discord.abc.Messageable, Generic[BotT]):
The bot that contains the command being executed. The bot that contains the command being executed.
args: :class:`list` args: :class:`list`
The list of transformed arguments that were passed into the command. The list of transformed arguments that were passed into the command.
If this is accessed during the :func:`.on_command_error` event If this is accessed during the :func:`on_command_error` event
then this list could be incomplete. then this list could be incomplete.
kwargs: :class:`dict` kwargs: :class:`dict`
A dictionary of transformed arguments that were passed into the command. A dictionary of transformed arguments that were passed into the command.
Similar to :attr:`args`\, if this is accessed in the Similar to :attr:`args`\, if this is accessed in the
:func:`.on_command_error` event then this dict could be incomplete. :func:`on_command_error` event then this dict could be incomplete.
current_parameter: Optional[:class:`inspect.Parameter`] prefix: :class:`str`
The parameter that is currently being inspected and converted.
This is only of use for within converters.
.. versionadded:: 2.0
prefix: Optional[:class:`str`]
The prefix that was used to invoke the command. The prefix that was used to invoke the command.
command: Optional[:class:`Command`] command: :class:`Command`
The command that is being invoked currently. The command that is being invoked currently.
invoked_with: Optional[:class:`str`] invoked_with: :class:`str`
The command name that triggered this invocation. Useful for finding out The command name that triggered this invocation. Useful for finding out
which alias called the command. which alias called the command.
invoked_parents: List[:class:`str`] invoked_parents: List[:class:`str`]
@@ -109,7 +68,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
.. versionadded:: 1.7 .. versionadded:: 1.7
invoked_subcommand: Optional[:class:`Command`] invoked_subcommand: :class:`Command`
The subcommand that was invoked. The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``. If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`] subcommand_passed: Optional[:class:`str`]
@@ -122,38 +81,31 @@ class Context(discord.abc.Messageable, Generic[BotT]):
or invoked. or invoked.
""" """
def __init__(self, def __init__(self, **attrs):
*, self.message = attrs.pop('message', None)
message: Message, self.bot = attrs.pop('bot', None)
bot: BotT, self.args = attrs.pop('args', [])
view: StringView, self.kwargs = attrs.pop('kwargs', {})
args: List[Any] = MISSING, self.prefix = attrs.pop('prefix')
kwargs: Dict[str, Any] = MISSING, self.command = attrs.pop('command', None)
prefix: Optional[str] = None, self.view = attrs.pop('view', None)
command: Optional[Command] = None, self.invoked_with = attrs.pop('invoked_with', None)
invoked_with: Optional[str] = None, self.invoked_parents = attrs.pop('invoked_parents', [])
invoked_parents: List[str] = MISSING, self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
invoked_subcommand: Optional[Command] = None, self.subcommand_passed = attrs.pop('subcommand_passed', None)
subcommand_passed: Optional[str] = None, self.command_failed = attrs.pop('command_failed', False)
command_failed: bool = False, self._state = self.message._state
current_parameter: Optional[inspect.Parameter] = None,
):
self.message: Message = message
self.bot: BotT = bot
self.args: List[Any] = args or []
self.kwargs: Dict[str, Any] = kwargs or {}
self.prefix: Optional[str] = prefix
self.command: Optional[Command] = command
self.view: StringView = view
self.invoked_with: Optional[str] = invoked_with
self.invoked_parents: List[str] = invoked_parents or []
self.invoked_subcommand: Optional[Command] = invoked_subcommand
self.subcommand_passed: Optional[str] = subcommand_passed
self.command_failed: bool = command_failed
self.current_parameter: Optional[inspect.Parameter] = current_parameter
self._state: ConnectionState = self.message._state
async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: @property
def clean_prefix(self):
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 1.5.1.4"""
user = self.guild.me if self.guild else self.bot.user
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
async def invoke(self, command, /, *args, **kwargs):
r"""|coro| r"""|coro|
Calls a command with the arguments given. Calls a command with the arguments given.
@@ -175,7 +127,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
command: :class:`.Command` command: :class:`.Command`
The command that is going to be called. The command that is going to be called.
\*args \*args
The arguments to use. The arguments to to use.
\*\*kwargs \*\*kwargs
The keyword arguments to use. The keyword arguments to use.
@@ -184,9 +136,17 @@ class Context(discord.abc.Messageable, Generic[BotT]):
TypeError TypeError
The command argument to invoke is missing. The command argument to invoke is missing.
""" """
return await command(self, *args, **kwargs) arguments = []
if command.cog is not None:
arguments.append(command.cog)
async def reinvoke(self, *, call_hooks: bool = False, restart: bool = True) -> None: arguments.append(self)
arguments.extend(args)
ret = await command.callback(*arguments, **kwargs)
return ret
async def reinvoke(self, *, call_hooks=False, restart=True):
"""|coro| """|coro|
Calls the command again. Calls the command again.
@@ -230,7 +190,7 @@ class Context(discord.abc.Messageable, Generic[BotT]):
if restart: if restart:
to_call = cmd.root_parent or cmd to_call = cmd.root_parent or cmd
view.index = len(self.prefix or '') view.index = len(self.prefix)
view.previous = 0 view.previous = 0
self.invoked_parents = [] self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command self.invoked_with = view.get_word() # advance to get the root command
@@ -249,32 +209,15 @@ class Context(discord.abc.Messageable, Generic[BotT]):
self.subcommand_passed = subcommand_passed self.subcommand_passed = subcommand_passed
@property @property
def valid(self) -> bool: def valid(self):
""":class:`bool`: Checks if the invocation context is valid to be invoked with.""" """:class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None return self.prefix is not None and self.command is not None
async def _get_channel(self) -> discord.abc.Messageable: async def _get_channel(self):
return self.channel return self.channel
@property @property
def clean_prefix(self) -> str: def cog(self):
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.
.. versionadded:: 2.0
"""
if self.prefix is None:
return ''
user = self.me
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(r"<@!?%s>" % user.id)
return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix)
@property
def cog(self) -> Optional[Cog]:
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist.""" """Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None: if self.command is None:
@@ -282,39 +225,38 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return self.command.cog return self.command.cog
@discord.utils.cached_property @discord.utils.cached_property
def guild(self) -> Optional[Guild]: def guild(self):
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available.""" """Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild return self.message.guild
@discord.utils.cached_property @discord.utils.cached_property
def channel(self) -> MessageableChannel: def channel(self):
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command. """Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`. Shorthand for :attr:`.Message.channel`.
""" """
return self.message.channel return self.message.channel
@discord.utils.cached_property @discord.utils.cached_property
def author(self) -> Union[User, Member]: def author(self):
"""Union[:class:`~discord.User`, :class:`.Member`]: """Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author` Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
""" """
return self.message.author return self.message.author
@discord.utils.cached_property @discord.utils.cached_property
def me(self) -> Union[Member, ClientUser]: def me(self):
"""Union[:class:`.Member`, :class:`.ClientUser`]: """Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts. Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
""" """
# bot.user will never be None at this point. return self.guild.me if self.guild is not None else self.bot.user
return self.guild.me if self.guild is not None else self.bot.user # type: ignore
@property @property
def voice_client(self) -> Optional[VoiceProtocol]: def voice_client(self):
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild g = self.guild
return g.voice_client if g else None return g.voice_client if g else None
async def send_help(self, *args: Any) -> Any: async def send_help(self, *args):
"""send_help(entity=<bot>) """send_help(entity=<bot>)
|coro| |coro|
@@ -366,12 +308,12 @@ class Context(discord.abc.Messageable, Generic[BotT]):
return None return None
entity = args[0] entity = args[0]
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
if entity is None: if entity is None:
return None return None
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
try: try:
entity.qualified_name entity.qualified_name
except AttributeError: except AttributeError:
@@ -395,6 +337,6 @@ class Context(discord.abc.Messageable, Generic[BotT]):
except CommandError as e: except CommandError as e:
await cmd.on_help_command_error(self, e) await cmd.on_help_command_error(self, e)
@discord.utils.copy_doc(Message.reply) @discord.utils.copy_doc(discord.Message.reply)
async def reply(self, content: Optional[str] = None, **kwargs: Any) -> Message: async def reply(self, content=None, **kwargs):
return await self.message.reply(content, **kwargs) return await self.message.reply(content, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -22,10 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, Callable, Deque, Dict, Optional, Type, TypeVar, TYPE_CHECKING
from discord.enums import Enum from discord.enums import Enum
import time import time
import asyncio import asyncio
@@ -34,20 +30,13 @@ from collections import deque
from ...abc import PrivateChannel from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached from .errors import MaxConcurrencyReached
if TYPE_CHECKING:
from ...message import Message
__all__ = ( __all__ = (
'BucketType', 'BucketType',
'Cooldown', 'Cooldown',
'CooldownMapping', 'CooldownMapping',
'DynamicCooldownMapping',
'MaxConcurrency', 'MaxConcurrency',
) )
C = TypeVar('C', bound='CooldownMapping')
MC = TypeVar('MC', bound='MaxConcurrency')
class BucketType(Enum): class BucketType(Enum):
default = 0 default = 0
user = 1 user = 1
@@ -57,7 +46,7 @@ class BucketType(Enum):
category = 5 category = 5
role = 6 role = 6
def get_key(self, msg: Message) -> Any: def get_key(self, msg):
if self is BucketType.user: if self is BucketType.user:
return msg.author.id return msg.author.id
elif self is BucketType.guild: elif self is BucketType.guild:
@@ -67,52 +56,33 @@ class BucketType(Enum):
elif self is BucketType.member: elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id) return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category: elif self is BucketType.category:
return (msg.channel.category or msg.channel).id # type: ignore return (msg.channel.category or msg.channel).id
elif self is BucketType.role: elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds # we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role # and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are # NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do # recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id # type: ignore return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
def __call__(self, msg: Message) -> Any: def __call__(self, msg):
return self.get_key(msg) return self.get_key(msg)
class Cooldown: class Cooldown:
"""Represents a cooldown for a command. __slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
Attributes def __init__(self, rate, per, type):
----------- self.rate = int(rate)
rate: :class:`int` self.per = float(per)
The total number of tokens available per :attr:`per` seconds. self.type = type
per: :class:`float` self._window = 0.0
The length of the cooldown period in seconds. self._tokens = self.rate
""" self._last = 0.0
__slots__ = ('rate', 'per', '_window', '_tokens', '_last') if not callable(self.type):
raise TypeError('Cooldown type must be a BucketType or callable')
def __init__(self, rate: float, per: float) -> None: def get_tokens(self, current=None):
self.rate: int = int(rate)
self.per: float = float(per)
self._window: float = 0.0
self._tokens: int = self.rate
self._last: float = 0.0
def get_tokens(self, current: Optional[float] = None) -> int:
"""Returns the number of available tokens before rate limiting is applied.
Parameters
------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to calculate tokens at.
If not supplied then :func:`time.time()` is used.
Returns
--------
:class:`int`
The number of tokens available before the cooldown is to be applied.
"""
if not current: if not current:
current = time.time() current = time.time()
@@ -122,20 +92,7 @@ class Cooldown:
tokens = self.rate tokens = self.rate
return tokens return tokens
def get_retry_after(self, current: Optional[float] = None) -> float: def get_retry_after(self, current=None):
"""Returns the time in seconds until the cooldown will be reset.
Parameters
-------------
current: Optional[:class:`float`]
The current time in seconds since Unix epoch.
If not supplied, then :func:`time.time()` is used.
Returns
-------
:class:`float`
The number of seconds to wait before this cooldown will be reset.
"""
current = current or time.time() current = current or time.time()
tokens = self.get_tokens(current) tokens = self.get_tokens(current)
@@ -144,20 +101,7 @@ class Cooldown:
return 0.0 return 0.0
def update_rate_limit(self, current: Optional[float] = None) -> Optional[float]: def update_rate_limit(self, current=None):
"""Updates the cooldown rate limit.
Parameters
-------------
current: Optional[:class:`float`]
The time in seconds since Unix epoch to update the rate limit at.
If not supplied, then :func:`time.time()` is used.
Returns
-------
Optional[:class:`float`]
The retry-after time in seconds if rate limited.
"""
current = current or time.time() current = current or time.time()
self._last = current self._last = current
@@ -174,58 +118,43 @@ class Cooldown:
# we're not so decrement our tokens # we're not so decrement our tokens
self._tokens -= 1 self._tokens -= 1
def reset(self) -> None: # see if we got rate limited due to this token change, and if
"""Reset the cooldown to its initial state.""" # so update the window to point to our current time frame
if self._tokens == 0:
self._window = current
def reset(self):
self._tokens = self.rate self._tokens = self.rate
self._last = 0.0 self._last = 0.0
def copy(self) -> Cooldown: def copy(self):
"""Creates a copy of this cooldown. return Cooldown(self.rate, self.per, self.type)
Returns def __repr__(self):
-------- return '<Cooldown rate: {0.rate} per: {0.per} window: {0._window} tokens: {0._tokens}>'.format(self)
:class:`Cooldown`
A new instance of this cooldown.
"""
return Cooldown(self.rate, self.per)
def __repr__(self) -> str:
return f'<Cooldown rate: {self.rate} per: {self.per} window: {self._window} tokens: {self._tokens}>'
class CooldownMapping: class CooldownMapping:
def __init__( def __init__(self, original):
self, self._cache = {}
original: Optional[Cooldown], self._cooldown = original
type: Callable[[Message], Any],
) -> None:
if not callable(type):
raise TypeError('Cooldown type must be a BucketType or callable')
self._cache: Dict[Any, Cooldown] = {} def copy(self):
self._cooldown: Optional[Cooldown] = original ret = CooldownMapping(self._cooldown)
self._type: Callable[[Message], Any] = type
def copy(self) -> CooldownMapping:
ret = CooldownMapping(self._cooldown, self._type)
ret._cache = self._cache.copy() ret._cache = self._cache.copy()
return ret return ret
@property @property
def valid(self) -> bool: def valid(self):
return self._cooldown is not None return self._cooldown is not None
@property
def type(self) -> Callable[[Message], Any]:
return self._type
@classmethod @classmethod
def from_cooldown(cls: Type[C], rate, per, type) -> C: def from_cooldown(cls, rate, per, type):
return cls(Cooldown(rate, per), type) return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg: Message) -> Any: def _bucket_key(self, msg):
return self._type(msg) return self._cooldown.type(msg)
def _verify_cache_integrity(self, current: Optional[float] = None) -> None: def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used # we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a # in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted # cooldown of 60s and it has not been used in 60s then that key should be deleted
@@ -234,50 +163,24 @@ class CooldownMapping:
for k in dead_keys: for k in dead_keys:
del self._cache[k] del self._cache[k]
def create_bucket(self, message: Message) -> Cooldown: def get_bucket(self, message, current=None):
return self._cooldown.copy() # type: ignore if self._cooldown.type is BucketType.default:
return self._cooldown
def get_bucket(self, message: Message, current: Optional[float] = None) -> Cooldown:
if self._type is BucketType.default:
return self._cooldown # type: ignore
self._verify_cache_integrity(current) self._verify_cache_integrity(current)
key = self._bucket_key(message) key = self._bucket_key(message)
if key not in self._cache: if key not in self._cache:
bucket = self.create_bucket(message) bucket = self._cooldown.copy()
if bucket is not None: self._cache[key] = bucket
self._cache[key] = bucket
else: else:
bucket = self._cache[key] bucket = self._cache[key]
return bucket return bucket
def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: def update_rate_limit(self, message, current=None):
bucket = self.get_bucket(message, current) bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current) return bucket.update_rate_limit(current)
class DynamicCooldownMapping(CooldownMapping):
def __init__(
self,
factory: Callable[[Message], Cooldown],
type: Callable[[Message], Any]
) -> None:
super().__init__(None, type)
self._factory: Callable[[Message], Cooldown] = factory
def copy(self) -> DynamicCooldownMapping:
ret = DynamicCooldownMapping(self._factory, self._type)
ret._cache = self._cache.copy()
return ret
@property
def valid(self) -> bool:
return True
def create_bucket(self, message: Message) -> Cooldown:
return self._factory(message)
class _Semaphore: class _Semaphore:
"""This class is a version of a semaphore. """This class is a version of a semaphore.
@@ -293,28 +196,28 @@ class _Semaphore:
__slots__ = ('value', 'loop', '_waiters') __slots__ = ('value', 'loop', '_waiters')
def __init__(self, number: int) -> None: def __init__(self, number):
self.value: int = number self.value = number
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._waiters: Deque[asyncio.Future] = deque() self._waiters = deque()
def __repr__(self) -> str: def __repr__(self):
return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
def locked(self) -> bool: def locked(self):
return self.value == 0 return self.value == 0
def is_active(self) -> bool: def is_active(self):
return len(self._waiters) > 0 return len(self._waiters) > 0
def wake_up(self) -> None: def wake_up(self):
while self._waiters: while self._waiters:
future = self._waiters.popleft() future = self._waiters.popleft()
if not future.done(): if not future.done():
future.set_result(None) future.set_result(None)
return return
async def acquire(self, *, wait: bool = False) -> bool: async def acquire(self, *, wait=False):
if not wait and self.value <= 0: if not wait and self.value <= 0:
# signal that we're not acquiring # signal that we're not acquiring
return False return False
@@ -333,18 +236,18 @@ class _Semaphore:
self.value -= 1 self.value -= 1
return True return True
def release(self) -> None: def release(self):
self.value += 1 self.value += 1
self.wake_up() self.wake_up()
class MaxConcurrency: class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping') __slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: def __init__(self, number, *, per, wait):
self._mapping: Dict[Any, _Semaphore] = {} self._mapping = {}
self.per: BucketType = per self.per = per
self.number: int = number self.number = number
self.wait: bool = wait self.wait = wait
if number <= 0: if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1') raise ValueError('max_concurrency \'number\' cannot be less than 1')
@@ -352,16 +255,16 @@ class MaxConcurrency:
if not isinstance(per, BucketType): if not isinstance(per, BucketType):
raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}')
def copy(self: MC) -> MC: def copy(self):
return self.__class__(self.number, per=self.per, wait=self.wait) return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self) -> str: def __repr__(self):
return f'<MaxConcurrency per={self.per!r} number={self.number} wait={self.wait}>' return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
def get_key(self, message: Message) -> Any: def get_key(self, message):
return self.per.get_key(message) return self.per.get_key(message)
async def acquire(self, message: Message) -> None: async def acquire(self, message):
key = self.get_key(message) key = self.get_key(message)
try: try:
@@ -373,7 +276,7 @@ class MaxConcurrency:
if not acquired: if not acquired:
raise MaxConcurrencyReached(self.number, self.per) raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message: Message) -> None: async def release(self, message):
# Technically there's no reason for this function to be async # Technically there's no reason for this function to be async
# But it might be more useful in the future # But it might be more useful in the future
key = self.get_key(message) key = self.get_key(message)

File diff suppressed because it is too large Load Diff

View File

@@ -22,23 +22,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Optional, Any, TYPE_CHECKING, List, Callable, Type, Tuple, Union
from discord.errors import ClientException, DiscordException from discord.errors import ClientException, DiscordException
if TYPE_CHECKING:
from inspect import Parameter
from .converter import Converter
from .context import Context
from .cooldowns import Cooldown, BucketType
from .flags import Flag
from discord.abc import GuildChannel
from discord.threads import Thread
from discord.types.snowflake import Snowflake, SnowflakeList
__all__ = ( __all__ = (
'CommandError', 'CommandError',
@@ -57,19 +42,15 @@ __all__ = (
'MaxConcurrencyReached', 'MaxConcurrencyReached',
'NotOwner', 'NotOwner',
'MessageNotFound', 'MessageNotFound',
'ObjectNotFound',
'MemberNotFound', 'MemberNotFound',
'GuildNotFound', 'GuildNotFound',
'UserNotFound', 'UserNotFound',
'ChannelNotFound', 'ChannelNotFound',
'ThreadNotFound',
'ChannelNotReadable', 'ChannelNotReadable',
'BadColourArgument', 'BadColourArgument',
'BadColorArgument',
'RoleNotFound', 'RoleNotFound',
'BadInviteArgument', 'BadInviteArgument',
'EmojiNotFound', 'EmojiNotFound',
'GuildStickerNotFound',
'PartialEmojiConversionFailure', 'PartialEmojiConversionFailure',
'BadBoolArgument', 'BadBoolArgument',
'MissingRole', 'MissingRole',
@@ -81,7 +62,6 @@ __all__ = (
'NSFWChannelRequired', 'NSFWChannelRequired',
'ConversionError', 'ConversionError',
'BadUnionArgument', 'BadUnionArgument',
'BadLiteralArgument',
'ArgumentParsingError', 'ArgumentParsingError',
'UnexpectedQuoteError', 'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError', 'InvalidEndOfQuotedStringError',
@@ -93,11 +73,6 @@ __all__ = (
'ExtensionFailed', 'ExtensionFailed',
'ExtensionNotFound', 'ExtensionNotFound',
'CommandRegistrationError', 'CommandRegistrationError',
'FlagError',
'BadFlagArgument',
'MissingFlagArgument',
'TooManyFlags',
'MissingRequiredFlag',
) )
class CommandError(DiscordException): class CommandError(DiscordException):
@@ -107,9 +82,9 @@ class CommandError(DiscordException):
This exception and exceptions inherited from it are handled This exception and exceptions inherited from it are handled
in a special way as they are caught and passed into a special event in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`.on_command_error`. from :class:`.Bot`\, :func:`on_command_error`.
""" """
def __init__(self, message: Optional[str] = None, *args: Any) -> None: def __init__(self, message=None, *args):
if message is not None: if message is not None:
# clean-up @everyone and @here mentions # clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
@@ -130,9 +105,9 @@ class ConversionError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, converter: Converter, original: Exception) -> None: def __init__(self, converter, original):
self.converter: Converter = converter self.converter = converter
self.original: Exception = original self.original = original
class UserInputError(CommandError): class UserInputError(CommandError):
"""The base exception type for errors that involve errors """The base exception type for errors that involve errors
@@ -164,8 +139,8 @@ class MissingRequiredArgument(UserInputError):
param: :class:`inspect.Parameter` param: :class:`inspect.Parameter`
The argument that is missing. The argument that is missing.
""" """
def __init__(self, param: Parameter) -> None: def __init__(self, param):
self.param: Parameter = param self.param = param
super().__init__(f'{param.name} is a required argument that is missing.') super().__init__(f'{param.name} is a required argument that is missing.')
class TooManyArguments(UserInputError): class TooManyArguments(UserInputError):
@@ -206,9 +181,9 @@ class CheckAnyFailure(CheckFailure):
A list of check predicates that failed. A list of check predicates that failed.
""" """
def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: def __init__(self, checks, errors):
self.checks: List[CheckFailure] = checks self.checks = checks
self.errors: List[Callable[[Context], bool]] = errors self.errors = errors
super().__init__('You do not have permission to run this command.') super().__init__('You do not have permission to run this command.')
class PrivateMessageOnly(CheckFailure): class PrivateMessageOnly(CheckFailure):
@@ -217,7 +192,7 @@ class PrivateMessageOnly(CheckFailure):
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message=None):
super().__init__(message or 'This command can only be used in private messages.') super().__init__(message or 'This command can only be used in private messages.')
class NoPrivateMessage(CheckFailure): class NoPrivateMessage(CheckFailure):
@@ -227,7 +202,7 @@ class NoPrivateMessage(CheckFailure):
This inherits from :exc:`CheckFailure` This inherits from :exc:`CheckFailure`
""" """
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message=None):
super().__init__(message or 'This command cannot be used in private messages.') super().__init__(message or 'This command cannot be used in private messages.')
class NotOwner(CheckFailure): class NotOwner(CheckFailure):
@@ -237,23 +212,6 @@ class NotOwner(CheckFailure):
""" """
pass pass
class ObjectNotFound(BadArgument):
"""Exception raised when the argument provided did not match the format
of an ID or a mention.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The argument supplied by the caller that was not matched
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'{argument!r} does not follow a valid ID or mention format.')
class MemberNotFound(BadArgument): class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's """Exception raised when the member provided was not found in the bot's
cache. cache.
@@ -267,8 +225,8 @@ class MemberNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The member supplied by the caller that was not found The member supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Member "{argument}" not found.') super().__init__(f'Member "{argument}" not found.')
class GuildNotFound(BadArgument): class GuildNotFound(BadArgument):
@@ -283,8 +241,8 @@ class GuildNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The guild supplied by the called that was not found The guild supplied by the called that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Guild "{argument}" not found.') super().__init__(f'Guild "{argument}" not found.')
class UserNotFound(BadArgument): class UserNotFound(BadArgument):
@@ -300,8 +258,8 @@ class UserNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The user supplied by the caller that was not found The user supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'User "{argument}" not found.') super().__init__(f'User "{argument}" not found.')
class MessageNotFound(BadArgument): class MessageNotFound(BadArgument):
@@ -316,8 +274,8 @@ class MessageNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The message supplied by the caller that was not found The message supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Message "{argument}" not found.') super().__init__(f'Message "{argument}" not found.')
class ChannelNotReadable(BadArgument): class ChannelNotReadable(BadArgument):
@@ -330,11 +288,11 @@ class ChannelNotReadable(BadArgument):
Attributes Attributes
----------- -----------
argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] argument: :class:`.abc.GuildChannel`
The channel supplied by the caller that was not readable The channel supplied by the caller that was not readable
""" """
def __init__(self, argument: Union[GuildChannel, Thread]) -> None: def __init__(self, argument):
self.argument: Union[GuildChannel, Thread] = argument self.argument = argument
super().__init__(f"Can't read messages in {argument.mention}.") super().__init__(f"Can't read messages in {argument.mention}.")
class ChannelNotFound(BadArgument): class ChannelNotFound(BadArgument):
@@ -349,26 +307,10 @@ class ChannelNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The channel supplied by the caller that was not found The channel supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Channel "{argument}" not found.') super().__init__(f'Channel "{argument}" not found.')
class ThreadNotFound(BadArgument):
"""Exception raised when the bot can not find the thread.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The thread supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Thread "{argument}" not found.')
class BadColourArgument(BadArgument): class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid. """Exception raised when the colour is not valid.
@@ -381,8 +323,8 @@ class BadColourArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The colour supplied by the caller that was not valid The colour supplied by the caller that was not valid
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Colour "{argument}" is invalid.') super().__init__(f'Colour "{argument}" is invalid.')
BadColorArgument = BadColourArgument BadColorArgument = BadColourArgument
@@ -399,8 +341,8 @@ class RoleNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The role supplied by the caller that was not found The role supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Role "{argument}" not found.') super().__init__(f'Role "{argument}" not found.')
class BadInviteArgument(BadArgument): class BadInviteArgument(BadArgument):
@@ -410,9 +352,8 @@ class BadInviteArgument(BadArgument):
.. versionadded:: 1.5 .. versionadded:: 1.5
""" """
def __init__(self, argument: str) -> None: def __init__(self):
self.argument: str = argument super().__init__('Invite is invalid or expired.')
super().__init__(f'Invite "{argument}" is invalid or expired.')
class EmojiNotFound(BadArgument): class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji. """Exception raised when the bot can not find the emoji.
@@ -426,8 +367,8 @@ class EmojiNotFound(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that was not found The emoji supplied by the caller that was not found
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Emoji "{argument}" not found.') super().__init__(f'Emoji "{argument}" not found.')
class PartialEmojiConversionFailure(BadArgument): class PartialEmojiConversionFailure(BadArgument):
@@ -443,26 +384,10 @@ class PartialEmojiConversionFailure(BadArgument):
argument: :class:`str` argument: :class:`str`
The emoji supplied by the caller that did not match the regex The emoji supplied by the caller that did not match the regex
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.')
class GuildStickerNotFound(BadArgument):
"""Exception raised when the bot can not find the sticker.
This inherits from :exc:`BadArgument`
.. versionadded:: 2.0
Attributes
-----------
argument: :class:`str`
The sticker supplied by the caller that was not found
"""
def __init__(self, argument: str) -> None:
self.argument: str = argument
super().__init__(f'Sticker "{argument}" not found.')
class BadBoolArgument(BadArgument): class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable. """Exception raised when a boolean argument was not convertable.
@@ -475,8 +400,8 @@ class BadBoolArgument(BadArgument):
argument: :class:`str` argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list The boolean argument supplied by the caller that is not in the predefined list
""" """
def __init__(self, argument: str) -> None: def __init__(self, argument):
self.argument: str = argument self.argument = argument
super().__init__(f'{argument} is not a recognised boolean option') super().__init__(f'{argument} is not a recognised boolean option')
class DisabledCommand(CommandError): class DisabledCommand(CommandError):
@@ -497,9 +422,9 @@ class CommandInvokeError(CommandError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, e: Exception) -> None: def __init__(self, e):
self.original: Exception = e self.original = e
super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') super().__init__('Command raised an exception: {0.__class__.__name__}: {0}'.format(e))
class CommandOnCooldown(CommandError): class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown. """Exception raised when the command being invoked is on cooldown.
@@ -508,18 +433,15 @@ class CommandOnCooldown(CommandError):
Attributes Attributes
----------- -----------
cooldown: :class:`.Cooldown` cooldown: Cooldown
A class with attributes ``rate`` and ``per`` similar to the A class with attributes ``rate``, ``per``, and ``type`` similar to
:func:`.cooldown` decorator. the :func:`.cooldown` decorator.
type: :class:`BucketType`
The type associated with the cooldown.
retry_after: :class:`float` retry_after: :class:`float`
The amount of seconds to wait before you can retry again. The amount of seconds to wait before you can retry again.
""" """
def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: def __init__(self, cooldown, retry_after):
self.cooldown: Cooldown = cooldown self.cooldown = cooldown
self.retry_after: float = retry_after self.retry_after = retry_after
self.type: BucketType = type
super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s')
class MaxConcurrencyReached(CommandError): class MaxConcurrencyReached(CommandError):
@@ -535,14 +457,14 @@ class MaxConcurrencyReached(CommandError):
The bucket type passed to the :func:`.max_concurrency` decorator. The bucket type passed to the :func:`.max_concurrency` decorator.
""" """
def __init__(self, number: int, per: BucketType) -> None: def __init__(self, number, per):
self.number: int = number self.number = number
self.per: BucketType = per self.per = per
name = per.name name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally' suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s' plural = '%s times %s' if number > 1 else '%s time %s'
fmt = plural % (number, suffix) fmt = plural % (number, suffix)
super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') super().__init__(f'Too many people using this command. It can only be used {fmt} concurrently.')
class MissingRole(CheckFailure): class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command. """Exception raised when the command invoker lacks a role to run a command.
@@ -557,8 +479,8 @@ class MissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role):
self.missing_role: Snowflake = missing_role self.missing_role = missing_role
message = f'Role {missing_role!r} is required to run this command.' message = f'Role {missing_role!r} is required to run this command.'
super().__init__(message) super().__init__(message)
@@ -575,8 +497,8 @@ class BotMissingRole(CheckFailure):
The required role that is missing. The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`. This is the parameter passed to :func:`~.commands.has_role`.
""" """
def __init__(self, missing_role: Snowflake) -> None: def __init__(self, missing_role):
self.missing_role: Snowflake = missing_role self.missing_role = missing_role
message = f'Bot requires the role {missing_role!r} to run this command' message = f'Bot requires the role {missing_role!r} to run this command'
super().__init__(message) super().__init__(message)
@@ -594,8 +516,8 @@ class MissingAnyRole(CheckFailure):
The roles that the invoker is missing. The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles):
self.missing_roles: SnowflakeList = missing_roles self.missing_roles = missing_roles
missing = [f"'{role}'" for role in missing_roles] missing = [f"'{role}'" for role in missing_roles]
@@ -623,8 +545,8 @@ class BotMissingAnyRole(CheckFailure):
These are the parameters passed to :func:`~.commands.has_any_role`. These are the parameters passed to :func:`~.commands.has_any_role`.
""" """
def __init__(self, missing_roles: SnowflakeList) -> None: def __init__(self, missing_roles):
self.missing_roles: SnowflakeList = missing_roles self.missing_roles = missing_roles
missing = [f"'{role}'" for role in missing_roles] missing = [f"'{role}'" for role in missing_roles]
@@ -645,11 +567,11 @@ class NSFWChannelRequired(CheckFailure):
Parameters Parameters
----------- -----------
channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] channel: :class:`discord.abc.GuildChannel`
The channel that does not have NSFW enabled. The channel that does not have NSFW enabled.
""" """
def __init__(self, channel: Union[GuildChannel, Thread]) -> None: def __init__(self, channel):
self.channel: Union[GuildChannel, Thread] = channel self.channel = channel
super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.")
class MissingPermissions(CheckFailure): class MissingPermissions(CheckFailure):
@@ -660,13 +582,13 @@ class MissingPermissions(CheckFailure):
Attributes Attributes
----------- -----------
missing_permissions: List[:class:`str`] missing_perms: :class:`list`
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_perms, *args):
self.missing_permissions: List[str] = missing_permissions self.missing_perms = missing_perms
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_perms]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
@@ -683,13 +605,13 @@ class BotMissingPermissions(CheckFailure):
Attributes Attributes
----------- -----------
missing_permissions: List[:class:`str`] missing_perms: :class:`list`
The required permissions that are missing. The required permissions that are missing.
""" """
def __init__(self, missing_permissions: List[str], *args: Any) -> None: def __init__(self, missing_perms, *args):
self.missing_permissions: List[str] = missing_permissions self.missing_perms = missing_perms
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_perms]
if len(missing) > 2: if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
@@ -708,22 +630,20 @@ class BadUnionArgument(UserInputError):
----------- -----------
param: :class:`inspect.Parameter` param: :class:`inspect.Parameter`
The parameter that failed being converted. The parameter that failed being converted.
converters: Tuple[Type, ``...``] converters: Tuple[Type, ...]
A tuple of converters attempted in conversion, in order of failure. A tuple of converters attempted in conversion, in order of failure.
errors: List[:class:`CommandError`] errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion. A list of errors that were caught from failing the conversion.
""" """
def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: def __init__(self, param, converters, errors):
self.param: Parameter = param self.param = param
self.converters: Tuple[Type, ...] = converters self.converters = converters
self.errors: List[CommandError] = errors self.errors = errors
def _get_name(x): def _get_name(x):
try: try:
return x.__name__ return x.__name__
except AttributeError: except AttributeError:
if hasattr(x, '__origin__'):
return repr(x)
return x.__class__.__name__ return x.__class__.__name__
to_string = [_get_name(x) for x in converters] to_string = [_get_name(x) for x in converters]
@@ -734,36 +654,6 @@ class BadUnionArgument(UserInputError):
super().__init__(f'Could not convert "{param.name}" into {fmt}.') super().__init__(f'Could not convert "{param.name}" into {fmt}.')
class BadLiteralArgument(UserInputError):
"""Exception raised when a :data:`typing.Literal` converter fails for all
its associated values.
This inherits from :exc:`UserInputError`
.. versionadded:: 2.0
Attributes
-----------
param: :class:`inspect.Parameter`
The parameter that failed being converted.
literals: Tuple[Any, ``...``]
A tuple of values compared against in conversion, in order of failure.
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None:
self.param: Parameter = param
self.literals: Tuple[Any, ...] = literals
self.errors: List[CommandError] = errors
to_string = [repr(l) for l in literals]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.')
class ArgumentParsingError(UserInputError): class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input. """An exception raised when the parser fails to parse a user's input.
@@ -784,8 +674,8 @@ class UnexpectedQuoteError(ArgumentParsingError):
quote: :class:`str` quote: :class:`str`
The quote mark that was found inside the non-quoted string. The quote mark that was found inside the non-quoted string.
""" """
def __init__(self, quote: str) -> None: def __init__(self, quote):
self.quote: str = quote self.quote = quote
super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string')
class InvalidEndOfQuotedStringError(ArgumentParsingError): class InvalidEndOfQuotedStringError(ArgumentParsingError):
@@ -799,8 +689,8 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError):
char: :class:`str` char: :class:`str`
The character found instead of the expected string. The character found instead of the expected string.
""" """
def __init__(self, char: str) -> None: def __init__(self, char):
self.char: str = char self.char = char
super().__init__(f'Expected space after closing quotation but received {char!r}') super().__init__(f'Expected space after closing quotation but received {char!r}')
class ExpectedClosingQuoteError(ArgumentParsingError): class ExpectedClosingQuoteError(ArgumentParsingError):
@@ -814,8 +704,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError):
The quote character expected. The quote character expected.
""" """
def __init__(self, close_quote: str) -> None: def __init__(self, close_quote):
self.close_quote: str = close_quote self.close_quote = close_quote
super().__init__(f'Expected closing {close_quote}.') super().__init__(f'Expected closing {close_quote}.')
class ExtensionError(DiscordException): class ExtensionError(DiscordException):
@@ -828,8 +718,8 @@ class ExtensionError(DiscordException):
name: :class:`str` name: :class:`str`
The extension that had an error. The extension that had an error.
""" """
def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: def __init__(self, message=None, *args, name):
self.name: str = name self.name = name
message = message or f'Extension {name!r} had an error.' message = message or f'Extension {name!r} had an error.'
# clean-up @everyone and @here mentions # clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
@@ -840,7 +730,7 @@ class ExtensionAlreadyLoaded(ExtensionError):
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name):
super().__init__(f'Extension {name!r} is already loaded.', name=name) super().__init__(f'Extension {name!r} is already loaded.', name=name)
class ExtensionNotLoaded(ExtensionError): class ExtensionNotLoaded(ExtensionError):
@@ -848,7 +738,7 @@ class ExtensionNotLoaded(ExtensionError):
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name):
super().__init__(f'Extension {name!r} has not been loaded.', name=name) super().__init__(f'Extension {name!r} has not been loaded.', name=name)
class NoEntryPointError(ExtensionError): class NoEntryPointError(ExtensionError):
@@ -856,7 +746,7 @@ class NoEntryPointError(ExtensionError):
This inherits from :exc:`ExtensionError` This inherits from :exc:`ExtensionError`
""" """
def __init__(self, name: str) -> None: def __init__(self, name):
super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) super().__init__(f"Extension {name!r} has no 'setup' function.", name=name)
class ExtensionFailed(ExtensionError): class ExtensionFailed(ExtensionError):
@@ -872,10 +762,10 @@ class ExtensionFailed(ExtensionError):
The original exception that was raised. You can also get this via The original exception that was raised. You can also get this via
the ``__cause__`` attribute. the ``__cause__`` attribute.
""" """
def __init__(self, name: str, original: Exception) -> None: def __init__(self, name, original):
self.original: Exception = original self.original = original
msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' fmt = 'Extension {0!r} raised an error: {1.__class__.__name__}: {1}'
super().__init__(msg, name=name) super().__init__(fmt.format(name, original), name=name)
class ExtensionNotFound(ExtensionError): class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found. """An exception raised when an extension is not found.
@@ -889,10 +779,13 @@ class ExtensionNotFound(ExtensionError):
----------- -----------
name: :class:`str` name: :class:`str`
The extension that had the error. The extension that had the error.
original: :class:`NoneType`
Always ``None`` for backwards compatibility.
""" """
def __init__(self, name: str) -> None: def __init__(self, name, original=None):
msg = f'Extension {name!r} could not be loaded.' self.original = None
super().__init__(msg, name=name) fmt = 'Extension {0!r} could not be loaded.'
super().__init__(fmt.format(name), name=name)
class CommandRegistrationError(ClientException): class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added """An exception raised when the command can't be added
@@ -909,89 +802,8 @@ class CommandRegistrationError(ClientException):
alias_conflict: :class:`bool` alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add. Whether the name that conflicts is an alias of the command we try to add.
""" """
def __init__(self, name: str, *, alias_conflict: bool = False) -> None: def __init__(self, name, *, alias_conflict=False):
self.name: str = name self.name = name
self.alias_conflict: bool = alias_conflict self.alias_conflict = alias_conflict
type_ = 'alias' if alias_conflict else 'command' type_ = 'alias' if alias_conflict else 'command'
super().__init__(f'The {type_} {name} is already an existing command or alias.') super().__init__(f'The {type_} {name} is already an existing command or alias.')
class FlagError(BadArgument):
"""The base exception type for all flag parsing related errors.
This inherits from :exc:`BadArgument`.
.. versionadded:: 2.0
"""
pass
class TooManyFlags(FlagError):
"""An exception raised when a flag has received too many values.
This inherits from :exc:`FlagError`.
.. versionadded:: 2.0
Attributes
------------
flag: :class:`~discord.ext.commands.Flag`
The flag that received too many values.
values: List[:class:`str`]
The values that were passed.
"""
def __init__(self, flag: Flag, values: List[str]) -> None:
self.flag: Flag = flag
self.values: List[str] = values
super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.')
class BadFlagArgument(FlagError):
"""An exception raised when a flag failed to convert a value.
This inherits from :exc:`FlagError`
.. versionadded:: 2.0
Attributes
-----------
flag: :class:`~discord.ext.commands.Flag`
The flag that failed to convert.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
try:
name = flag.annotation.__name__
except AttributeError:
name = flag.annotation.__class__.__name__
super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}')
class MissingRequiredFlag(FlagError):
"""An exception raised when a required flag was not given.
This inherits from :exc:`FlagError`
.. versionadded:: 2.0
Attributes
-----------
flag: :class:`~discord.ext.commands.Flag`
The required flag that was not found.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} is required and missing')
class MissingFlagArgument(FlagError):
"""An exception raised when a flag did not get a value.
This inherits from :exc:`FlagError`
.. versionadded:: 2.0
Attributes
-----------
flag: :class:`~discord.ext.commands.Flag`
The flag that did not get a value.
"""
def __init__(self, flag: Flag) -> None:
self.flag: Flag = flag
super().__init__(f'Flag {flag.name!r} does not have an argument')

View File

@@ -1,618 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from .errors import (
BadFlagArgument,
CommandError,
MissingFlagArgument,
TooManyFlags,
MissingRequiredFlag,
)
from discord.utils import resolve_annotation
from .view import StringView
from .converter import run_converters
from discord.utils import maybe_coroutine, MISSING
from dataclasses import dataclass, field
from typing import (
Dict,
Iterator,
Literal,
Optional,
Pattern,
Set,
TYPE_CHECKING,
Tuple,
List,
Any,
Type,
TypeVar,
Union,
)
import inspect
import sys
import re
__all__ = (
'Flag',
'flag',
'FlagConverter',
)
if TYPE_CHECKING:
from .context import Context
@dataclass
class Flag:
"""Represents a flag parameter for :class:`FlagConverter`.
The :func:`~discord.ext.commands.flag` function helps
create these flag objects, but it is not necessary to
do so. These cannot be constructed manually.
Attributes
------------
name: :class:`str`
The name of the flag.
aliases: List[:class:`str`]
The aliases of the flag name.
attribute: :class:`str`
The attribute in the class that corresponds to this flag.
default: Any
The default value of the flag, if available.
annotation: Any
The underlying evaluated annotation of the flag.
max_args: :class:`int`
The maximum number of arguments the flag can accept.
A negative value indicates an unlimited amount of arguments.
override: :class:`bool`
Whether multiple given values overrides the previous value.
"""
name: str = MISSING
aliases: List[str] = field(default_factory=list)
attribute: str = MISSING
annotation: Any = MISSING
default: Any = MISSING
max_args: int = MISSING
override: bool = MISSING
cast_to_dict: bool = False
@property
def required(self) -> bool:
""":class:`bool`: Whether the flag is required.
A required flag has no default value.
"""
return self.default is MISSING
def flag(
*,
name: str = MISSING,
aliases: List[str] = MISSING,
default: Any = MISSING,
max_args: int = MISSING,
override: bool = MISSING,
) -> Any:
"""Override default functionality and parameters of the underlying :class:`FlagConverter`
class attributes.
Parameters
------------
name: :class:`str`
The flag name. If not given, defaults to the attribute name.
aliases: List[:class:`str`]
Aliases to the flag name. If not given no aliases are set.
default: Any
The default parameter. This could be either a value or a callable that takes
:class:`Context` as its sole parameter. If not given then it defaults to
the default value given to the attribute.
max_args: :class:`int`
The maximum number of arguments the flag can accept.
A negative value indicates an unlimited amount of arguments.
The default value depends on the annotation given.
override: :class:`bool`
Whether multiple given values overrides the previous value. The default
value depends on the annotation given.
"""
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
def validate_flag_name(name: str, forbidden: Set[str]):
if not name:
raise ValueError('flag names should not be empty')
for ch in name:
if ch.isspace():
raise ValueError(f'flag name {name!r} cannot have spaces')
if ch == '\\':
raise ValueError(f'flag name {name!r} cannot have backslashes')
if ch in forbidden:
raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them')
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
annotations = namespace.get('__annotations__', {})
case_insensitive = namespace['__commands_flag_case_insensitive__']
flags: Dict[str, Flag] = {}
cache: Dict[str, Any] = {}
names: Set[str] = set()
for name, annotation in annotations.items():
flag = namespace.pop(name, MISSING)
if isinstance(flag, Flag):
flag.annotation = annotation
else:
flag = Flag(name=name, annotation=annotation, default=flag)
flag.attribute = name
if flag.name is MISSING:
flag.name = name
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible():
flag.default = annotation._construct_default
if flag.aliases is MISSING:
flag.aliases = []
# Add sensible defaults based off of the type annotation
# <type> -> (max_args=1)
# List[str] -> (max_args=-1)
# Tuple[int, ...] -> (max_args=1)
# Dict[K, V] -> (max_args=-1, override=True)
# Union[str, int] -> (max_args=1)
# Optional[str] -> (default=None, max_args=1)
try:
origin = annotation.__origin__
except AttributeError:
# A regular type hint
if flag.max_args is MISSING:
flag.max_args = 1
else:
if origin is Union:
# typing.Union
if flag.max_args is MISSING:
flag.max_args = 1
if annotation.__args__[-1] is type(None) and flag.default is MISSING:
# typing.Optional
flag.default = None
elif origin is tuple:
# typing.Tuple
# tuple parsing is e.g. `flag: peter 20`
# for Tuple[str, int] would give you flag: ('peter', 20)
if flag.max_args is MISSING:
flag.max_args = 1
elif origin is list:
# typing.List
if flag.max_args is MISSING:
flag.max_args = -1
elif origin is dict:
# typing.Dict[K, V]
# Equivalent to:
# typing.List[typing.Tuple[K, V]]
flag.cast_to_dict = True
if flag.max_args is MISSING:
flag.max_args = -1
if flag.override is MISSING:
flag.override = True
elif origin is Literal:
if flag.max_args is MISSING:
flag.max_args = 1
else:
raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag')
if flag.override is MISSING:
flag.override = False
# Validate flag names are unique
name = flag.name.casefold() if case_insensitive else flag.name
if name in names:
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
else:
names.add(name)
for alias in flag.aliases:
# Validate alias is unique
alias = alias.casefold() if case_insensitive else alias
if alias in names:
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
else:
names.add(alias)
flags[flag.name] = flag
return flags
class FlagsMeta(type):
if TYPE_CHECKING:
__commands_is_flag__: bool
__commands_flags__: Dict[str, Flag]
__commands_flag_aliases__: Dict[str, str]
__commands_flag_regex__: Pattern[str]
__commands_flag_case_insensitive__: bool
__commands_flag_delimiter__: str
__commands_flag_prefix__: str
def __new__(
cls: Type[type],
name: str,
bases: Tuple[type, ...],
attrs: Dict[str, Any],
*,
case_insensitive: bool = MISSING,
delimiter: str = MISSING,
prefix: str = MISSING,
):
attrs['__commands_is_flag__'] = True
try:
global_ns = sys.modules[attrs['__module__']].__dict__
except KeyError:
global_ns = {}
frame = inspect.currentframe()
try:
if frame is None:
local_ns = {}
else:
if frame.f_back is None:
local_ns = frame.f_locals
else:
local_ns = frame.f_back.f_locals
finally:
del frame
flags: Dict[str, Flag] = {}
aliases: Dict[str, str] = {}
for base in reversed(bases):
if base.__dict__.get('__commands_is_flag__', False):
flags.update(base.__dict__['__commands_flags__'])
aliases.update(base.__dict__['__commands_flag_aliases__'])
if case_insensitive is MISSING:
attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__']
if delimiter is MISSING:
attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__']
if prefix is MISSING:
attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__']
if case_insensitive is not MISSING:
attrs['__commands_flag_case_insensitive__'] = case_insensitive
if delimiter is not MISSING:
attrs['__commands_flag_delimiter__'] = delimiter
if prefix is not MISSING:
attrs['__commands_flag_prefix__'] = prefix
case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False)
delimiter = attrs.setdefault('__commands_flag_delimiter__', ':')
prefix = attrs.setdefault('__commands_flag_prefix__', '')
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
flags[flag_name] = flag
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
forbidden = set(delimiter).union(prefix)
for flag_name in flags:
validate_flag_name(flag_name, forbidden)
for alias_name in aliases:
validate_flag_name(alias_name, forbidden)
regex_flags = 0
if case_insensitive:
flags = {key.casefold(): value for key, value in flags.items()}
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
regex_flags = re.IGNORECASE
keys = list(re.escape(k) for k in flags)
keys.extend(re.escape(a) for a in aliases)
keys = sorted(keys, key=lambda t: len(t), reverse=True)
joined = '|'.join(keys)
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
attrs['__commands_flag_regex__'] = pattern
attrs['__commands_flags__'] = flags
attrs['__commands_flag_aliases__'] = aliases
return type.__new__(cls, name, bases, attrs)
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
while not view.eof:
view.skip_ws()
if view.eof:
break
word = view.get_quoted_word()
if word is None:
break
try:
converted = await run_converters(ctx, converter, word, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
else:
results.append(converted)
return tuple(results)
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
view = StringView(argument)
results = []
param: inspect.Parameter = ctx.current_parameter # type: ignore
for converter in converters:
view.skip_ws()
if view.eof:
break
word = view.get_quoted_word()
if word is None:
break
try:
converted = await run_converters(ctx, converter, word, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
else:
results.append(converted)
if len(results) != len(converters):
raise BadFlagArgument(flag)
return tuple(results)
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
param: inspect.Parameter = ctx.current_parameter # type: ignore
annotation = annotation or flag.annotation
try:
origin = annotation.__origin__
except AttributeError:
pass
else:
if origin is tuple:
if annotation.__args__[-1] is Ellipsis:
return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0])
else:
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
elif origin is list:
# typing.List[x]
annotation = annotation.__args__[0]
return await convert_flag(ctx, argument, flag, annotation)
elif origin is Union and annotation.__args__[-1] is type(None):
# typing.Optional[x]
annotation = Union[annotation.__args__[:-1]]
return await run_converters(ctx, annotation, argument, param)
elif origin is dict:
# typing.Dict[K, V] -> typing.Tuple[K, V]
return await tuple_convert_flag(ctx, argument, flag, annotation.__args__)
try:
return await run_converters(ctx, annotation, argument, param)
except CommandError:
raise
except Exception as e:
raise BadFlagArgument(flag) from e
F = TypeVar('F', bound='FlagConverter')
class FlagConverter(metaclass=FlagsMeta):
"""A converter that allows for a user-friendly flag syntax.
The flags are defined using :pep:`526` type annotations similar
to the :mod:`dataclasses` Python module. For more information on
how this converter works, check the appropriate
:ref:`documentation <ext_commands_flag_converter>`.
.. container:: operations
.. describe:: iter(x)
Returns an iterator of ``(flag_name, flag_value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.0
Parameters
-----------
case_insensitive: :class:`bool`
A class parameter to toggle case insensitivity of the flag parsing.
If ``True`` then flags are parsed in a case insensitive manner.
Defaults to ``False``.
prefix: :class:`str`
The prefix that all flags must be prefixed with. By default
there is no prefix.
delimiter: :class:`str`
The delimiter that separates a flag's argument from the flag's name.
By default this is ``:``.
"""
@classmethod
def get_flags(cls) -> Dict[str, Flag]:
"""Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has."""
return cls.__commands_flags__.copy()
@classmethod
def _can_be_constructible(cls) -> bool:
return all(not flag.required for flag in cls.__commands_flags__.values())
def __iter__(self) -> Iterator[Tuple[str, Any]]:
for flag in self.__class__.__commands_flags__.values():
yield (flag.name, getattr(self, flag.attribute))
@classmethod
async def _construct_default(cls: Type[F], ctx: Context) -> F:
self: F = cls.__new__(cls)
flags = cls.__commands_flags__
for flag in flags.values():
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
return self
def __repr__(self) -> str:
pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
return f'<{self.__class__.__name__} {pairs}>'
@classmethod
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
result: Dict[str, List[str]] = {}
flags = cls.__commands_flags__
aliases = cls.__commands_flag_aliases__
last_position = 0
last_flag: Optional[Flag] = None
case_insensitive = cls.__commands_flag_case_insensitive__
for match in cls.__commands_flag_regex__.finditer(argument):
begin, end = match.span(0)
key = match.group('flag')
if case_insensitive:
key = key.casefold()
if key in aliases:
key = aliases[key]
flag = flags.get(key)
if last_position and last_flag is not None:
value = argument[last_position : begin - 1].lstrip()
if not value:
raise MissingFlagArgument(last_flag)
try:
values = result[last_flag.name]
except KeyError:
result[last_flag.name] = [value]
else:
values.append(value)
last_position = end
last_flag = flag
# Add the remaining string to the last available flag
if last_position and last_flag is not None:
value = argument[last_position:].strip()
if not value:
raise MissingFlagArgument(last_flag)
try:
values = result[last_flag.name]
except KeyError:
result[last_flag.name] = [value]
else:
values.append(value)
# Verification of values will come at a later stage
return result
@classmethod
async def convert(cls: Type[F], ctx: Context, argument: str) -> F:
"""|coro|
The method that actually converters an argument to the flag mapping.
Parameters
----------
cls: Type[:class:`FlagConverter`]
The flag converter class.
ctx: :class:`Context`
The invocation context.
argument: :class:`str`
The argument to convert from.
Raises
--------
FlagError
A flag related parsing error.
CommandError
A command related error.
Returns
--------
:class:`FlagConverter`
The flag converter instance with all flags parsed.
"""
arguments = cls.parse_flags(argument)
flags = cls.__commands_flags__
self: F = cls.__new__(cls)
for name, flag in flags.items():
try:
values = arguments[name]
except KeyError:
if flag.required:
raise MissingRequiredFlag(flag)
else:
if callable(flag.default):
default = await maybe_coroutine(flag.default, ctx)
setattr(self, flag.attribute, default)
else:
setattr(self, flag.attribute, flag.default)
continue
if flag.max_args > 0 and len(values) > flag.max_args:
if flag.override:
values = values[-flag.max_args :]
else:
raise TooManyFlags(flag, values)
# Special case:
if flag.max_args == 1:
value = await convert_flag(ctx, values[0], flag)
setattr(self, flag.attribute, value)
continue
# Another special case, tuple parsing.
# Tuple parsing is basically converting arguments within the flag
# So, given flag: hello 20 as the input and Tuple[str, int] as the type hint
# We would receive ('hello', 20) as the resulting value
# This uses the same whitespace and quoting rules as regular parameters.
values = [await convert_flag(ctx, value, flag) for value in values]
if flag.cast_to_dict:
values = dict(values) # type: ignore
setattr(self, flag.attribute, values)
return self

View File

@@ -27,17 +27,11 @@ import copy
import functools import functools
import inspect import inspect
import re import re
from typing import Optional, TYPE_CHECKING
import discord.utils import discord.utils
from .core import Group, Command from .core import Group, Command
from .errors import CommandError from .errors import CommandError
if TYPE_CHECKING:
from .context import Context
__all__ = ( __all__ = (
'Paginator', 'Paginator',
'HelpCommand', 'HelpCommand',
@@ -66,7 +60,6 @@ __all__ = (
# Type <prefix>help command for more info on a command. # Type <prefix>help command for more info on a command.
# You can also type <prefix>help category for more info on a category. # You can also type <prefix>help category for more info on a category.
class Paginator: class Paginator:
"""A class that aids in paginating code blocks for Discord messages. """A class that aids in paginating code blocks for Discord messages.
@@ -88,7 +81,6 @@ class Paginator:
The character string inserted between lines. e.g. a newline character. The character string inserted between lines. e.g. a newline character.
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'):
self.prefix = prefix self.prefix = prefix
self.suffix = suffix self.suffix = suffix
@@ -100,7 +92,7 @@ class Paginator:
"""Clears the paginator to have no pages.""" """Clears the paginator to have no pages."""
if self.prefix is not None: if self.prefix is not None:
self._current_page = [self.prefix] self._current_page = [self.prefix]
self._count = len(self.prefix) + self._linesep_len # prefix + newline self._count = len(self.prefix) + self._linesep_len # prefix + newline
else: else:
self._current_page = [] self._current_page = []
self._count = 0 self._count = 0
@@ -158,7 +150,7 @@ class Paginator:
if self.prefix is not None: if self.prefix is not None:
self._current_page = [self.prefix] self._current_page = [self.prefix]
self._count = len(self.prefix) + self._linesep_len # prefix + linesep self._count = len(self.prefix) + self._linesep_len # prefix + linesep
else: else:
self._current_page = [] self._current_page = []
self._count = 0 self._count = 0
@@ -179,12 +171,10 @@ class Paginator:
fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>' fmt = '<Paginator prefix: {0.prefix!r} suffix: {0.suffix!r} linesep: {0.linesep!r} max_size: {0.max_size} count: {0._count}>'
return fmt.format(self) return fmt.format(self)
def _not_overriden(f): def _not_overriden(f):
f.__help_command_not_overriden__ = True f.__help_command_not_overriden__ = True
return f return f
class _HelpCommandImpl(Command): class _HelpCommandImpl(Command):
def __init__(self, inject, *args, **kwargs): def __init__(self, inject, *args, **kwargs):
super().__init__(inject.command_callback, *args, **kwargs) super().__init__(inject.command_callback, *args, **kwargs)
@@ -222,8 +212,8 @@ class _HelpCommandImpl(Command):
def clean_params(self): def clean_params(self):
result = self.params.copy() result = self.params.copy()
try: try:
del result[next(iter(result))] result.popitem(last=False)
except StopIteration: except Exception:
raise ValueError('Missing context parameter') from None raise ValueError('Missing context parameter') from None
else: else:
return result return result
@@ -260,7 +250,6 @@ class _HelpCommandImpl(Command):
cog.walk_commands = cog.walk_commands.__wrapped__ cog.walk_commands = cog.walk_commands.__wrapped__
self.cog = None self.cog = None
class HelpCommand: class HelpCommand:
r"""The base implementation for help command formatting. r"""The base implementation for help command formatting.
@@ -283,9 +272,9 @@ class HelpCommand:
Defaults to ``False``. Defaults to ``False``.
verify_checks: Optional[:class:`bool`] verify_checks: Optional[:class:`bool`]
Specifies if commands should have their :attr:`.Command.checks` called Specifies if commands should have their :attr:`.Command.checks` called
and verified. If ``True``, always calls :attr:`.Command.checks`. and verified. If ``True``, always calls :attr:`.Commands.checks`.
If ``None``, only calls :attr:`.Command.checks` in a guild setting. If ``None``, only calls :attr:`.Commands.checks` in a guild setting.
If ``False``, never calls :attr:`.Command.checks`. Defaults to ``True``. If ``False``, never calls :attr:`.Commands.checks`. Defaults to ``True``.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
command_attrs: :class:`dict` command_attrs: :class:`dict`
@@ -299,7 +288,7 @@ class HelpCommand:
'@everyone': '@\u200beveryone', '@everyone': '@\u200beveryone',
'@here': '@\u200bhere', '@here': '@\u200bhere',
r'<@!?[0-9]{17,22}>': '@deleted-user', r'<@!?[0-9]{17,22}>': '@deleted-user',
r'<@&[0-9]{17,22}>': '@deleted-role', r'<@&[0-9]{17,22}>': '@deleted-role'
} }
MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys()))
@@ -316,7 +305,10 @@ class HelpCommand:
# The keys can be safely copied as-is since they're 99.99% certain of being # The keys can be safely copied as-is since they're 99.99% certain of being
# string keys # string keys
deepcopy = copy.deepcopy deepcopy = copy.deepcopy
self.__original_kwargs__ = {k: deepcopy(v) for k, v in kwargs.items()} self.__original_kwargs__ = {
k: deepcopy(v)
for k, v in kwargs.items()
}
self.__original_args__ = deepcopy(args) self.__original_args__ = deepcopy(args)
return self return self
@@ -326,7 +318,7 @@ class HelpCommand:
self.command_attrs = attrs = options.pop('command_attrs', {}) self.command_attrs = attrs = options.pop('command_attrs', {})
attrs.setdefault('name', 'help') attrs.setdefault('name', 'help')
attrs.setdefault('help', 'Shows this message') attrs.setdefault('help', 'Shows this message')
self.context: Context = discord.utils.MISSING self.context = None
self._command_impl = _HelpCommandImpl(self, **self.command_attrs) self._command_impl = _HelpCommandImpl(self, **self.command_attrs)
def copy(self): def copy(self):
@@ -377,10 +369,25 @@ class HelpCommand:
def get_bot_mapping(self): def get_bot_mapping(self):
"""Retrieves the bot mapping passed to :meth:`send_bot_help`.""" """Retrieves the bot mapping passed to :meth:`send_bot_help`."""
bot = self.context.bot bot = self.context.bot
mapping = {cog: cog.get_commands() for cog in bot.cogs.values()} mapping = {
cog: cog.get_commands()
for cog in bot.cogs.values()
}
mapping[None] = [c for c in bot.commands if c.cog is None] mapping[None] = [c for c in bot.commands if c.cog is None]
return mapping return mapping
@property
def clean_prefix(self):
""":class:`str`: The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``."""
user = self.context.guild.me if self.context.guild else self.context.bot.user
# this breaks if the prefix mention is not the bot itself but I
# consider this to be an *incredibly* strange use case. I'd rather go
# for this common use case rather than waste performance for the
# odd one.
pattern = re.compile(fr"<@!?{user.id}>")
display_name = user.display_name.replace('\\', r'\\')
return pattern.sub('@' + display_name, self.context.prefix)
@property @property
def invoked_with(self): def invoked_with(self):
"""Similar to :attr:`Context.invoked_with` except properly handles """Similar to :attr:`Context.invoked_with` except properly handles
@@ -435,7 +442,7 @@ class HelpCommand:
else: else:
alias = command.name if not parent_sig else parent_sig + ' ' + command.name alias = command.name if not parent_sig else parent_sig + ' ' + command.name
return f'{self.context.clean_prefix}{alias} {command.signature}' return f'{self.clean_prefix}{alias} {command.signature}'
def remove_mentions(self, string): def remove_mentions(self, string):
"""Removes mentions from the string to prevent abuse. """Removes mentions from the string to prevent abuse.
@@ -600,7 +607,10 @@ class HelpCommand:
The maximum width of the commands. The maximum width of the commands.
""" """
as_lengths = (discord.utils._string_width(c.name) for c in commands) as_lengths = (
discord.utils._string_width(c.name)
for c in commands
)
return max(as_lengths, default=0) return max(as_lengths, default=0)
def get_destination(self): def get_destination(self):
@@ -621,7 +631,8 @@ class HelpCommand:
"""|coro| """|coro|
Handles the implementation when an error happens in the help command. Handles the implementation when an error happens in the help command.
For example, the result of :meth:`command_not_found` will be passed here. For example, the result of :meth:`command_not_found` or
:meth:`command_has_no_subcommand_found` will be passed here.
You can override this method to customise the behaviour. You can override this method to customise the behaviour.
@@ -869,7 +880,6 @@ class HelpCommand:
else: else:
return await self.send_command_help(cmd) return await self.send_command_help(cmd)
class DefaultHelpCommand(HelpCommand): class DefaultHelpCommand(HelpCommand):
"""The implementation of the default help command. """The implementation of the default help command.
@@ -924,16 +934,14 @@ class DefaultHelpCommand(HelpCommand):
def shorten_text(self, text): def shorten_text(self, text):
""":class:`str`: Shortens text to fit into the :attr:`width`.""" """:class:`str`: Shortens text to fit into the :attr:`width`."""
if len(text) > self.width: if len(text) > self.width:
return text[:self.width - 3].rstrip() + '...' return text[:self.width - 3] + '...'
return text return text
def get_ending_note(self): def get_ending_note(self):
""":class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes.""" """:class:`str`: Returns help command's ending note. This is mainly useful to override for i18n purposes."""
command_name = self.invoked_with command_name = self.invoked_with
return ( return f"Type {self.clean_prefix}{command_name} command for more info on a command.\n" \
f"Type {self.context.clean_prefix}{command_name} command for more info on a command.\n" f"You can also type {self.clean_prefix}{command_name} category for more info on a category."
f"You can also type {self.context.clean_prefix}{command_name} category for more info on a category."
)
def add_indented_commands(self, commands, *, heading, max_size=None): def add_indented_commands(self, commands, *, heading, max_size=None):
"""Indents a list of commands after the specified heading. """Indents a list of commands after the specified heading.
@@ -954,7 +962,7 @@ class DefaultHelpCommand(HelpCommand):
if the list of commands is greater than 0. if the list of commands is greater than 0.
max_size: Optional[:class:`int`] max_size: Optional[:class:`int`]
The max size to use for the gap between indents. The max size to use for the gap between indents.
If unspecified, calls :meth:`~HelpCommand.get_max_size` on the If unspecified, calls :meth:`get_max_size` on the
commands parameter. commands parameter.
""" """
@@ -1022,7 +1030,6 @@ class DefaultHelpCommand(HelpCommand):
self.paginator.add_line(bot.description, empty=True) self.paginator.add_line(bot.description, empty=True)
no_category = f'\u200b{self.no_category}:' no_category = f'\u200b{self.no_category}:'
def get_category(command, *, no_category=no_category): def get_category(command, *, no_category=no_category):
cog = command.cog cog = command.cog
return cog.qualified_name + ':' if cog is not None else no_category return cog.qualified_name + ':' if cog is not None else no_category
@@ -1076,7 +1083,6 @@ class DefaultHelpCommand(HelpCommand):
await self.send_pages() await self.send_pages()
class MinimalHelpCommand(HelpCommand): class MinimalHelpCommand(HelpCommand):
"""An implementation of a help command with minimal output. """An implementation of a help command with minimal output.
@@ -1143,13 +1149,11 @@ class MinimalHelpCommand(HelpCommand):
The help command opening note. The help command opening note.
""" """
command_name = self.invoked_with command_name = self.invoked_with
return ( return "Use `{0}{1} [command]` for more info on a command.\n" \
f"Use `{self.context.clean_prefix}{command_name} [command]` for more info on a command.\n" "You can also use `{0}{1} [category]` for more info on a category.".format(self.clean_prefix, command_name)
f"You can also use `{self.context.clean_prefix}{command_name} [category]` for more info on a category."
)
def get_command_signature(self, command): def get_command_signature(self, command):
return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' return f'{self.clean_prefix}{command.qualified_name} {command.signature}'
def get_ending_note(self): def get_ending_note(self):
"""Return the help command's ending note. This is mainly useful to override for i18n purposes. """Return the help command's ending note. This is mainly useful to override for i18n purposes.
@@ -1198,7 +1202,7 @@ class MinimalHelpCommand(HelpCommand):
The command to show information of. The command to show information of.
""" """
fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}'
self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) self.paginator.add_line(fmt.format(self.clean_prefix, command.qualified_name, command.short_doc))
def add_aliases_formatting(self, aliases): def add_aliases_formatting(self, aliases):
"""Adds the formatting information on a command's aliases. """Adds the formatting information on a command's aliases.
@@ -1269,7 +1273,6 @@ class MinimalHelpCommand(HelpCommand):
self.paginator.add_line(note, empty=True) self.paginator.add_line(note, empty=True)
no_category = f'\u200b{self.no_category}' no_category = f'\u200b{self.no_category}'
def get_category(command, *, no_category=no_category): def get_category(command, *, no_category=no_category):
cog = command.cog cog = command.cog
return cog.qualified_name if cog is not None else no_category return cog.qualified_name if cog is not None else no_category

View File

@@ -189,4 +189,4 @@ class StringView:
def __repr__(self): def __repr__(self):
return f'<StringView pos: {self.index} prev: {self.previous} end: {self.end} eof: {self.eof}>' return '<StringView pos: {0.index} prev: {0.previous} end: {0.end} eof: {0.eof}>'.format(self)

File diff suppressed because it is too large Load Diff

View File

@@ -22,92 +22,35 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import (
Any,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
TypeVar,
Union,
)
import aiohttp import aiohttp
import discord import discord
import inspect import inspect
import logging
import sys import sys
import traceback import traceback
from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'loop', 'loop',
) )
T = TypeVar('T') class Loop:
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
class SleepHandle:
__slots__ = ('future', 'loop', 'handle')
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
self.future = future = loop.create_future()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = loop.call_later(relative_delta, future.set_result, True)
def recalculate(self, dt: datetime.datetime) -> None:
self.handle.cancel()
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
def wait(self) -> asyncio.Future[Any]:
return self.future
def done(self) -> bool:
return self.future.done()
def cancel(self) -> None:
self.handle.cancel()
self.future.cancel()
class Loop(Generic[LF]):
"""A background task helper that abstracts the loop and reconnection logic for you. """A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
""" """
def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop):
def __init__( self.coro = coro
self, self.reconnect = reconnect
coro: LF, self.loop = loop
seconds: float, self.count = count
hours: float,
minutes: float,
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: asyncio.AbstractEventLoop,
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count
self._current_loop = 0 self._current_loop = 0
self._handle: SleepHandle = MISSING self._task = None
self._task: asyncio.Task[None] = MISSING
self._injected = None self._injected = None
self._valid_exception = ( self._valid_exception = (
OSError, OSError,
@@ -126,15 +69,15 @@ class Loop(Generic[LF]):
if self.count is not None and self.count <= 0: if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.') raise ValueError('count must be greater than 0 or None.')
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self.change_interval(seconds=seconds, minutes=minutes, hours=hours)
self._last_iteration_failed = False self._last_iteration_failed = False
self._last_iteration: datetime.datetime = MISSING self._last_iteration = None
self._next_iteration = None self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro): if not inspect.iscoroutinefunction(self.coro):
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') raise TypeError('Expected coroutine function, not {0.__name__!r}.'.format(type(self.coro)))
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: async def _call_loop_function(self, name, *args, **kwargs):
coro = getattr(self, '_' + name) coro = getattr(self, '_' + name)
if coro is None: if coro is None:
return return
@@ -144,22 +87,14 @@ class Loop(Generic[LF]):
else: else:
await coro(*args, **kwargs) await coro(*args, **kwargs)
def _try_sleep_until(self, dt: datetime.datetime): async def _loop(self, *args, **kwargs):
self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait()
async def _loop(self, *args: Any, **kwargs: Any) -> None:
backoff = ExponentialBackoff() backoff = ExponentialBackoff()
await self._call_loop_function('before_loop') await self._call_loop_function('before_loop')
sleep_until = discord.utils.sleep_until
self._last_iteration_failed = False self._last_iteration_failed = False
if self._time is not MISSING: self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
# the time index should be prepared every time the internal loop is started
self._prepare_time_index()
self._next_iteration = self._get_next_sleep_time()
else:
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
try: try:
await self._try_sleep_until(self._next_iteration) await asyncio.sleep(0) # allows canceling in before_loop
while True: while True:
if not self._last_iteration_failed: if not self._last_iteration_failed:
self._last_iteration = self._next_iteration self._last_iteration = self._next_iteration
@@ -167,27 +102,22 @@ class Loop(Generic[LF]):
try: try:
await self.coro(*args, **kwargs) await self.coro(*args, **kwargs)
self._last_iteration_failed = False self._last_iteration_failed = False
now = datetime.datetime.now(datetime.timezone.utc)
if now > self._next_iteration:
self._next_iteration = now
except self._valid_exception: except self._valid_exception:
self._last_iteration_failed = True self._last_iteration_failed = True
if not self.reconnect: if not self.reconnect:
raise raise
await asyncio.sleep(backoff.delay()) await asyncio.sleep(backoff.delay())
else: else:
await self._try_sleep_until(self._next_iteration) await sleep_until(self._next_iteration)
if self._stop_next_iteration: if self._stop_next_iteration:
return return
now = datetime.datetime.now(datetime.timezone.utc)
if now > self._next_iteration:
self._next_iteration = now
if self._time is not MISSING:
self._prepare_time_index(now)
self._current_loop += 1 self._current_loop += 1
if self._current_loop == self.count: if self._current_loop == self.count:
break break
except asyncio.CancelledError: except asyncio.CancelledError:
self._is_being_cancelled = True self._is_being_cancelled = True
raise raise
@@ -197,26 +127,17 @@ class Loop(Generic[LF]):
raise exc raise exc
finally: finally:
await self._call_loop_function('after_loop') await self._call_loop_function('after_loop')
self._handle.cancel()
self._is_being_cancelled = False self._is_being_cancelled = False
self._current_loop = 0 self._current_loop = 0
self._stop_next_iteration = False self._stop_next_iteration = False
self._has_failed = False self._has_failed = False
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]: def __get__(self, obj, objtype):
if obj is None: if obj is None:
return self return self
copy: Loop[LF] = Loop( copy = Loop(self.coro, seconds=self.seconds, hours=self.hours, minutes=self.minutes,
self.coro, count=self.count, reconnect=self.reconnect, loop=self.loop)
seconds=self._seconds,
hours=self._hours,
minutes=self._minutes,
time=self._time,
count=self.count,
reconnect=self.reconnect,
loop=self.loop,
)
copy._injected = obj copy._injected = obj
copy._before_loop = self._before_loop copy._before_loop = self._before_loop
copy._after_loop = self._after_loop copy._after_loop = self._after_loop
@@ -225,63 +146,23 @@ class Loop(Generic[LF]):
return copy return copy
@property @property
def seconds(self) -> Optional[float]: def current_loop(self):
"""Optional[:class:`float`]: Read-only value for the number of seconds
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._seconds is not MISSING:
return self._seconds
@property
def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._minutes is not MISSING:
return self._minutes
@property
def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
.. versionadded:: 2.0
"""
if self._hours is not MISSING:
return self._hours
@property
def time(self) -> Optional[List[datetime.time]]:
"""Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at.
``None`` if relative times were passed instead.
.. versionadded:: 2.0
"""
if self._time is not MISSING:
return self._time.copy()
@property
def current_loop(self) -> int:
""":class:`int`: The current iteration of the loop.""" """:class:`int`: The current iteration of the loop."""
return self._current_loop return self._current_loop
@property @property
def next_iteration(self) -> Optional[datetime.datetime]: def next_iteration(self):
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur. """Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur.
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
if self._task is MISSING: if self._task is None:
return None return None
elif self._task and self._task.done() or self._stop_next_iteration: elif self._task and self._task.done() or self._stop_next_iteration:
return None return None
return self._next_iteration return self._next_iteration
async def __call__(self, *args: Any, **kwargs: Any) -> Any: async def __call__(self, *args, **kwargs):
r"""|coro| r"""|coro|
Calls the internal callback that the task holds. Calls the internal callback that the task holds.
@@ -301,7 +182,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs) return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: def start(self, *args, **kwargs):
r"""Starts the internal task in the event loop. r"""Starts the internal task in the event loop.
Parameters Parameters
@@ -322,19 +203,19 @@ class Loop(Generic[LF]):
The task that has been created. The task that has been created.
""" """
if self._task is not MISSING and not self._task.done(): if self._task is not None and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.') raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None: if self._injected is not None:
args = (self._injected, *args) args = (self._injected, *args)
if self.loop is MISSING: if self.loop is None:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs)) self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task return self._task
def stop(self) -> None: def stop(self):
r"""Gracefully stops the task from running. r"""Gracefully stops the task from running.
Unlike :meth:`cancel`\, this allows the task to finish its Unlike :meth:`cancel`\, this allows the task to finish its
@@ -352,18 +233,18 @@ class Loop(Generic[LF]):
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
if self._task is not MISSING and not self._task.done(): if self._task and not self._task.done():
self._stop_next_iteration = True self._stop_next_iteration = True
def _can_be_cancelled(self) -> bool: def _can_be_cancelled(self):
return bool(not self._is_being_cancelled and self._task and not self._task.done()) return not self._is_being_cancelled and self._task and not self._task.done()
def cancel(self) -> None: def cancel(self):
"""Cancels the internal task, if it is running.""" """Cancels the internal task, if it is running."""
if self._can_be_cancelled(): if self._can_be_cancelled():
self._task.cancel() self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None: def restart(self, *args, **kwargs):
r"""A convenience method to restart the internal task. r"""A convenience method to restart the internal task.
.. note:: .. note::
@@ -374,12 +255,12 @@ class Loop(Generic[LF]):
Parameters Parameters
------------ ------------
\*args \*args
The arguments to use. The arguments to to use.
\*\*kwargs \*\*kwargs
The keyword arguments to use. The keyword arguments to use.
""" """
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: def restart_when_over(fut, *, args=args, kwargs=kwargs):
self._task.remove_done_callback(restart_when_over) self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs) self.start(*args, **kwargs)
@@ -387,7 +268,7 @@ class Loop(Generic[LF]):
self._task.add_done_callback(restart_when_over) self._task.add_done_callback(restart_when_over)
self._task.cancel() self._task.cancel()
def add_exception_type(self, *exceptions: Type[BaseException]) -> None: def add_exception_type(self, *exceptions):
r"""Adds exception types to be handled during the reconnect logic. r"""Adds exception types to be handled during the reconnect logic.
By default the exception types handled are those handled by By default the exception types handled are those handled by
@@ -416,7 +297,7 @@ class Loop(Generic[LF]):
self._valid_exception = (*self._valid_exception, *exceptions) self._valid_exception = (*self._valid_exception, *exceptions)
def clear_exception_types(self) -> None: def clear_exception_types(self):
"""Removes all exception types that are handled. """Removes all exception types that are handled.
.. note:: .. note::
@@ -425,7 +306,7 @@ class Loop(Generic[LF]):
""" """
self._valid_exception = tuple() self._valid_exception = tuple()
def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool: def remove_exception_type(self, *exceptions):
r"""Removes exception types from being handled during the reconnect logic. r"""Removes exception types from being handled during the reconnect logic.
Parameters Parameters
@@ -442,34 +323,34 @@ class Loop(Generic[LF]):
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions) return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self) -> Optional[asyncio.Task[None]]: def get_task(self):
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task if self._task is not MISSING else None return self._task
def is_being_cancelled(self) -> bool: def is_being_cancelled(self):
"""Whether the task is being cancelled.""" """Whether the task is being cancelled."""
return self._is_being_cancelled return self._is_being_cancelled
def failed(self) -> bool: def failed(self):
""":class:`bool`: Whether the internal task has failed. """:class:`bool`: Whether the internal task has failed.
.. versionadded:: 1.2 .. versionadded:: 1.2
""" """
return self._has_failed return self._has_failed
def is_running(self) -> bool: def is_running(self):
""":class:`bool`: Check if the task is currently running. """:class:`bool`: Check if the task is currently running.
.. versionadded:: 1.4 .. versionadded:: 1.4
""" """
return not bool(self._task.done()) if self._task is not MISSING else False return not bool(self._task.done()) if self._task else False
async def _error(self, *args: Any) -> None: async def _error(self, *args):
exception: Exception = args[-1] exception = args[-1]
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro: FT) -> FT: def before_loop(self, coro):
"""A decorator that registers a coroutine to be called before the loop starts running. """A decorator that registers a coroutine to be called before the loop starts running.
This is useful if you want to wait for some bot state before the loop starts, This is useful if you want to wait for some bot state before the loop starts,
@@ -489,12 +370,12 @@ class Loop(Generic[LF]):
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._before_loop = coro self._before_loop = coro
return coro return coro
def after_loop(self, coro: FT) -> FT: def after_loop(self, coro):
"""A decorator that register a coroutine to be called after the loop finished running. """A decorator that register a coroutine to be called after the loop finished running.
The coroutine must take no arguments (except ``self`` in a class context). The coroutine must take no arguments (except ``self`` in a class context).
@@ -517,12 +398,12 @@ class Loop(Generic[LF]):
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._after_loop = coro self._after_loop = coro
return coro return coro
def error(self, coro: ET) -> ET: def error(self, coro):
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception. """A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
The coroutine must take only one argument the exception raised (except ``self`` in a class context). The coroutine must take only one argument the exception raised (except ``self`` in a class context).
@@ -543,90 +424,22 @@ class Loop(Generic[LF]):
The function was not a coroutine. The function was not a coroutine.
""" """
if not inspect.iscoroutinefunction(coro): if not inspect.iscoroutinefunction(coro):
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._error = coro # type: ignore self._error = coro
return coro return coro
def _get_next_sleep_time(self) -> datetime.datetime: def _get_next_sleep_time(self):
if self._sleep is not MISSING: return self._last_iteration + datetime.timedelta(seconds=self._sleep)
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
if self._time_index >= len(self._time): def change_interval(self, *, seconds=0, minutes=0, hours=0):
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)
next_time = self._time[self._time_index]
if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
next_date = self._last_iteration
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)
self._time_index += 1
return datetime.datetime.combine(next_date, next_time)
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from
# pre-condition: self._time is set
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
break
else:
self._time_index = 0
def _get_time_parameter(
self,
time: Union[datetime.time, Sequence[datetime.time]],
*,
dt: Type[datetime.time] = datetime.time,
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [inner]
if not isinstance(time, Sequence):
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
ret = sorted(set(ret)) # de-dupe and sort times
return ret
def change_interval(
self,
*,
seconds: float = 0,
minutes: float = 0,
hours: float = 0,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
) -> None:
"""Changes the interval for the sleep time. """Changes the interval for the sleep time.
.. note::
This only applies on the next loop iteration. If it is desirable for the change of interval
to be applied right away, cancel the task with :meth:`cancel`.
.. versionadded:: 1.2 .. versionadded:: 1.2
Parameters Parameters
@@ -637,66 +450,23 @@ class Loop(Generic[LF]):
The number of minutes between every iteration. The number of minutes between every iteration.
hours: :class:`float` hours: :class:`float`
The number of hours between every iteration. The number of hours between every iteration.
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed.
This cannot be used in conjunction with the relative time parameters.
.. versionadded:: 2.0
.. note::
Duplicate times will be ignored, and only run once.
Raises Raises
------- -------
ValueError ValueError
An invalid value was given. An invalid value was given.
TypeError
An invalid value for the ``time`` parameter was passed, or the
``time`` parameter was passed in conjunction with relative time parameters.
""" """
if time is MISSING: sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
seconds = seconds or 0 if sleep < 0:
minutes = minutes or 0 raise ValueError('Total number of seconds cannot be less than zero.')
hours = hours or 0
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
self._sleep = sleep self._sleep = sleep
self._seconds = float(seconds) self.seconds = seconds
self._hours = float(hours) self.hours = hours
self._minutes = float(minutes) self.minutes = minutes
self._time: List[datetime.time] = MISSING
else:
if any((seconds, minutes, hours)):
raise TypeError('Cannot mix explicit time with relative time')
self._time = self._get_time_parameter(time)
self._sleep = self._seconds = self._minutes = self._hours = MISSING
if self.is_running(): def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None):
if self._time is not MISSING:
# prepare the next time index starting from after the last iteration
self._prepare_time_index(now=self._last_iteration)
self._next_iteration = self._get_next_sleep_time()
if not self._handle.done():
# the loop is sleeping, recalculate based on new interval
self._handle.recalculate(self._next_iteration)
def loop(
*,
seconds: float = MISSING,
minutes: float = MISSING,
hours: float = MISSING,
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -708,19 +478,6 @@ def loop(
The number of minutes between every iteration. The number of minutes between every iteration.
hours: :class:`float` hours: :class:`float`
The number of hours between every iteration. The number of hours between every iteration.
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters.
.. note::
Duplicate times will be ignored, and only run once.
.. versionadded:: 2.0
count: Optional[:class:`int`] count: Optional[:class:`int`]
The number of loops to do, ``None`` if it should be an The number of loops to do, ``None`` if it should be an
infinite loop. infinite loop.
@@ -737,20 +494,16 @@ def loop(
ValueError ValueError
An invalid value was given. An invalid value was given.
TypeError TypeError
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine.
or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func):
def decorator(func: LF) -> Loop[LF]: kwargs = {
return Loop[LF]( 'seconds': seconds,
func, 'minutes': minutes,
seconds=seconds, 'hours': hours,
minutes=minutes, 'count': count,
hours=hours, 'reconnect': reconnect,
count=count, 'loop': loop
time=time, }
reconnect=reconnect, return Loop(func, **kwargs)
loop=loop,
)
return decorator return decorator

View File

@@ -22,17 +22,13 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations import os.path
from typing import Optional, TYPE_CHECKING, Union
import os
import io import io
__all__ = ( __all__ = (
'File', 'File',
) )
class File: class File:
r"""A parameter object used for :meth:`abc.Messageable.send` r"""A parameter object used for :meth:`abc.Messageable.send`
for sending file objects. for sending file objects.
@@ -44,7 +40,7 @@ class File:
Attributes Attributes
----------- -----------
fp: Union[:class:`os.PathLike`, :class:`io.BufferedIOBase`] fp: Union[:class:`str`, :class:`io.BufferedIOBase`]
A file-like object opened in binary mode and read mode A file-like object opened in binary mode and read mode
or a filename representing a file in the hard drive to or a filename representing a file in the hard drive to
open. open.
@@ -66,18 +62,9 @@ class File:
__slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer') __slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer')
if TYPE_CHECKING: def __init__(self, fp, filename=None, *, spoiler=False):
fp: io.BufferedIOBase self.fp = fp
filename: Optional[str]
spoiler: bool
def __init__(
self,
fp: Union[str, bytes, os.PathLike, io.BufferedIOBase],
filename: Optional[str] = None,
*,
spoiler: bool = False,
):
if isinstance(fp, io.IOBase): if isinstance(fp, io.IOBase):
if not (fp.seekable() and fp.readable()): if not (fp.seekable() and fp.readable()):
raise ValueError(f'File buffer {fp!r} must be seekable and readable') raise ValueError(f'File buffer {fp!r} must be seekable and readable')
@@ -109,7 +96,7 @@ class File:
self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_')) self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_'))
def reset(self, *, seek: Union[int, bool] = True) -> None: def reset(self, *, seek=True):
# The `seek` parameter is needed because # The `seek` parameter is needed because
# the retry-loop is iterated over multiple times # the retry-loop is iterated over multiple times
# starting from 0, as an implementation quirk # starting from 0, as an implementation quirk
@@ -121,7 +108,7 @@ class File:
if seek: if seek:
self.fp.seek(self._original_pos) self.fp.seek(self._original_pos)
def close(self) -> None: def close(self):
self.fp.close = self._closer self.fp.close = self._closer
if self._owner: if self._owner:
self._closer() self._closer()

View File

@@ -34,14 +34,12 @@ __all__ = (
'PublicUserFlags', 'PublicUserFlags',
'Intents', 'Intents',
'MemberCacheFlags', 'MemberCacheFlags',
'ApplicationFlags',
) )
FV = TypeVar('FV', bound='flag_value') FV = TypeVar('FV', bound='flag_value')
BF = TypeVar('BF', bound='BaseFlags') BF = TypeVar('BF', bound='BaseFlags')
class flag_value(Generic[BF]):
class flag_value:
def __init__(self, func: Callable[[Any], int]): def __init__(self, func: Callable[[Any], int]):
self.flag = func(None) self.flag = func(None)
self.__doc__ = func.__doc__ self.__doc__ = func.__doc__
@@ -65,20 +63,16 @@ class flag_value:
def __repr__(self): def __repr__(self):
return f'<flag_value flag={self.flag!r}>' return f'<flag_value flag={self.flag!r}>'
class alias_flag_value(flag_value): class alias_flag_value(flag_value):
pass pass
def fill_with_flags(*, inverted: bool = False): def fill_with_flags(*, inverted: bool = False):
def decorator(cls: Type[BF]): def decorator(cls: Type[BF]):
# fmt: off
cls.VALID_FLAGS = { cls.VALID_FLAGS = {
name: value.flag name: value.flag
for name, value in cls.__dict__.items() for name, value in cls.__dict__.items()
if isinstance(value, flag_value) if isinstance(value, flag_value)
} }
# fmt: on
if inverted: if inverted:
max_bits = max(cls.VALID_FLAGS.values()).bit_length() max_bits = max(cls.VALID_FLAGS.values()).bit_length()
@@ -87,10 +81,8 @@ def fill_with_flags(*, inverted: bool = False):
cls.DEFAULT_VALUE = 0 cls.DEFAULT_VALUE = 0
return cls return cls
return decorator return decorator
# n.b. flags must inherit from this and use the decorator above # n.b. flags must inherit from this and use the decorator above
class BaseFlags: class BaseFlags:
VALID_FLAGS: ClassVar[Dict[str, int]] VALID_FLAGS: ClassVar[Dict[str, int]]
@@ -144,7 +136,6 @@ class BaseFlags:
else: else:
raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.') raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.')
@fill_with_flags(inverted=True) @fill_with_flags(inverted=True)
class SystemChannelFlags(BaseFlags): class SystemChannelFlags(BaseFlags):
r"""Wraps up a Discord system channel flag value. r"""Wraps up a Discord system channel flag value.
@@ -205,17 +196,9 @@ class SystemChannelFlags(BaseFlags):
@flag_value @flag_value
def premium_subscriptions(self): def premium_subscriptions(self):
""":class:`bool`: Returns ``True`` if the system channel is used for "Nitro boosting" notifications.""" """:class:`bool`: Returns ``True`` if the system channel is used for Nitro boosting notifications."""
return 2 return 2
@flag_value
def guild_reminder_notifications(self):
""":class:`bool`: Returns ``True`` if the system channel is used for server setup helpful tips notifications.
.. versionadded:: 2.0
"""
return 4
@fill_with_flags() @fill_with_flags()
class MessageFlags(BaseFlags): class MessageFlags(BaseFlags):
@@ -279,23 +262,6 @@ class MessageFlags(BaseFlags):
""" """
return 16 return 16
@flag_value
def has_thread(self):
""":class:`bool`: Returns ``True`` if the source message is associated with a thread.
.. versionadded:: 2.0
"""
return 32
@flag_value
def ephemeral(self):
""":class:`bool`: Returns ``True`` if the source message is ephemeral.
.. versionadded:: 2.0
"""
return 64
@fill_with_flags() @fill_with_flags()
class PublicUserFlags(BaseFlags): class PublicUserFlags(BaseFlags):
r"""Wraps up the Discord User Public flags. r"""Wraps up the Discord User Public flags.
@@ -402,14 +368,6 @@ class PublicUserFlags(BaseFlags):
""" """
return UserFlags.verified_bot_developer.value return UserFlags.verified_bot_developer.value
@flag_value
def discord_certified_moderator(self):
""":class:`bool`: Returns ``True`` if the user is a Discord Certified Moderator.
.. versionadded:: 2.0
"""
return UserFlags.discord_certified_moderator.value
def all(self) -> List[UserFlags]: def all(self) -> List[UserFlags]:
"""List[:class:`UserFlags`]: Returns all public flags the user has.""" """List[:class:`UserFlags`]: Returns all public flags the user has."""
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
@@ -464,6 +422,22 @@ class Intents(BaseFlags):
raise TypeError(f'{key!r} is not a valid flag name.') raise TypeError(f'{key!r} is not a valid flag name.')
setattr(self, key, value) setattr(self, key, value)
@classmethod
def from_list(cls, intents_list):
"""A factory method that creates a :class:`Intents` with everything enabled
that has been passed in the list.
.. versionadded:: 1.5.0.1"""
for item in intents_list:
if item not in cls.VALID_FLAGS.keys():
intents_list.remove(item)
self = cls.none()
for item in intents_list:
setattr(self, item, True)
return self
@classmethod @classmethod
def all(cls: Type[Intents]) -> Intents: def all(cls: Type[Intents]) -> Intents:
"""A factory method that creates a :class:`Intents` with everything enabled.""" """A factory method that creates a :class:`Intents` with everything enabled."""
@@ -524,13 +498,12 @@ class Intents(BaseFlags):
- :func:`on_member_join` - :func:`on_member_join`
- :func:`on_member_remove` - :func:`on_member_remove`
- :func:`on_member_update` - :func:`on_member_update` (nickname, roles)
- :func:`on_user_update` - :func:`on_user_update`
This also corresponds to the following attributes and classes in terms of cache: This also corresponds to the following attributes and classes in terms of cache:
- :meth:`Client.get_all_members` - :meth:`Client.get_all_members`
- :meth:`Client.get_user`
- :meth:`Guild.chunk` - :meth:`Guild.chunk`
- :meth:`Guild.fetch_members` - :meth:`Guild.fetch_members`
- :meth:`Guild.get_member` - :meth:`Guild.get_member`
@@ -539,7 +512,7 @@ class Intents(BaseFlags):
- :attr:`Member.nick` - :attr:`Member.nick`
- :attr:`Member.premium_since` - :attr:`Member.premium_since`
- :attr:`User.name` - :attr:`User.name`
- :attr:`User.avatar` - :attr:`User.avatar` (:attr:`User.avatar_url` and :meth:`User.avatar_url_as`)
- :attr:`User.discriminator` - :attr:`User.discriminator`
For more information go to the :ref:`member intent documentation <need_members_intent>`. For more information go to the :ref:`member intent documentation <need_members_intent>`.
@@ -566,34 +539,18 @@ class Intents(BaseFlags):
@flag_value @flag_value
def emojis(self): def emojis(self):
""":class:`bool`: Alias of :attr:`.emojis_and_stickers`. """:class:`bool`: Whether guild emoji related events are enabled.
.. versionchanged:: 2.0
Changed to an alias.
"""
return 1 << 3
@alias_flag_value
def emojis_and_stickers(self):
""":class:`bool`: Whether guild emoji and sticker related events are enabled.
.. versionadded:: 2.0
This corresponds to the following events: This corresponds to the following events:
- :func:`on_guild_emojis_update` - :func:`on_guild_emojis_update`
- :func:`on_guild_stickers_update`
This also corresponds to the following attributes and classes in terms of cache: This also corresponds to the following attributes and classes in terms of cache:
- :class:`Emoji` - :class:`Emoji`
- :class:`GuildSticker`
- :meth:`Client.get_emoji` - :meth:`Client.get_emoji`
- :meth:`Client.get_sticker`
- :meth:`Client.emojis` - :meth:`Client.emojis`
- :meth:`Client.stickers`
- :attr:`Guild.emojis` - :attr:`Guild.emojis`
- :attr:`Guild.stickers`
""" """
return 1 << 3 return 1 << 3
@@ -604,9 +561,6 @@ class Intents(BaseFlags):
This corresponds to the following events: This corresponds to the following events:
- :func:`on_guild_integrations_update` - :func:`on_guild_integrations_update`
- :func:`on_integration_create`
- :func:`on_integration_update`
- :func:`on_raw_integration_delete`
This does not correspond to any attributes or classes in the library in terms of cache. This does not correspond to any attributes or classes in the library in terms of cache.
""" """
@@ -650,20 +604,17 @@ class Intents(BaseFlags):
- :attr:`VoiceChannel.members` - :attr:`VoiceChannel.members`
- :attr:`VoiceChannel.voice_states` - :attr:`VoiceChannel.voice_states`
- :attr:`Member.voice` - :attr:`Member.voice`
.. note::
This intent is required to connect to voice.
""" """
return 1 << 7 return 1 << 7
@flag_value @flag_value
def presences(self): def presences(self):
""":class:`bool`: Whether guild presence related events are enabled. """:class:`bool`: Whether guild presence related events are enabled.
This corresponds to the following events: This corresponds to the following events:
- :func:`on_presence_update` - :func:`on_member_update` (activities, status)
This also corresponds to the following attributes and classes in terms of cache: This also corresponds to the following attributes and classes in terms of cache:
@@ -693,6 +644,7 @@ class Intents(BaseFlags):
- :func:`on_message_delete` (both guilds and DMs) - :func:`on_message_delete` (both guilds and DMs)
- :func:`on_raw_message_delete` (both guilds and DMs) - :func:`on_raw_message_delete` (both guilds and DMs)
- :func:`on_raw_message_edit` (both guilds and DMs) - :func:`on_raw_message_edit` (both guilds and DMs)
- :func:`on_private_channel_create`
This also corresponds to the following attributes and classes in terms of cache: This also corresponds to the following attributes and classes in terms of cache:
@@ -747,6 +699,7 @@ class Intents(BaseFlags):
- :func:`on_message_delete` (only for DMs) - :func:`on_message_delete` (only for DMs)
- :func:`on_raw_message_delete` (only for DMs) - :func:`on_raw_message_delete` (only for DMs)
- :func:`on_raw_message_edit` (only for DMs) - :func:`on_raw_message_edit` (only for DMs)
- :func:`on_private_channel_create`
This also corresponds to the following attributes and classes in terms of cache: This also corresponds to the following attributes and classes in terms of cache:
@@ -866,7 +819,6 @@ class Intents(BaseFlags):
""" """
return 1 << 14 return 1 << 14
@fill_with_flags() @fill_with_flags()
class MemberCacheFlags(BaseFlags): class MemberCacheFlags(BaseFlags):
"""Controls the library's cache policy when it comes to members. """Controls the library's cache policy when it comes to members.
@@ -940,6 +892,17 @@ class MemberCacheFlags(BaseFlags):
def _empty(self): def _empty(self):
return self.value == self.DEFAULT_VALUE return self.value == self.DEFAULT_VALUE
@flag_value
def online(self):
""":class:`bool`: Whether to cache members with a status.
For example, members that are part of the initial ``GUILD_CREATE``
or become online at a later point. This requires :attr:`Intents.presences`.
Members that go offline are no longer cached.
"""
return 1
@flag_value @flag_value
def voice(self): def voice(self):
""":class:`bool`: Whether to cache members that are in voice. """:class:`bool`: Whether to cache members that are in voice.
@@ -948,7 +911,7 @@ class MemberCacheFlags(BaseFlags):
Members that leave voice are no longer cached. Members that leave voice are no longer cached.
""" """
return 1 return 2
@flag_value @flag_value
def joined(self): def joined(self):
@@ -959,7 +922,7 @@ class MemberCacheFlags(BaseFlags):
Members that leave the guild are no longer cached. Members that leave the guild are no longer cached.
""" """
return 2 return 4
@classmethod @classmethod
def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags: def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags:
@@ -980,89 +943,35 @@ class MemberCacheFlags(BaseFlags):
self = cls.none() self = cls.none()
if intents.members: if intents.members:
self.joined = True self.joined = True
if intents.presences:
self.online = True
if intents.voice_states: if intents.voice_states:
self.voice = True self.voice = True
if not self.joined and self.online and self.voice:
self.voice = False
return self return self
def _verify_intents(self, intents: Intents): def _verify_intents(self, intents: Intents):
if self.online and not intents.presences:
raise ValueError('MemberCacheFlags.online requires Intents.presences enabled')
if self.voice and not intents.voice_states: if self.voice and not intents.voice_states:
raise ValueError('MemberCacheFlags.voice requires Intents.voice_states') raise ValueError('MemberCacheFlags.voice requires Intents.voice_states')
if self.joined and not intents.members: if self.joined and not intents.members:
raise ValueError('MemberCacheFlags.joined requires Intents.members') raise ValueError('MemberCacheFlags.joined requires Intents.members')
if not self.joined and self.voice and self.online:
msg = 'Setting both MemberCacheFlags.voice and MemberCacheFlags.online requires MemberCacheFlags.joined ' \
'to properly evict members from the cache.'
raise ValueError(msg)
@property @property
def _voice_only(self): def _voice_only(self):
return self.value == 2
@property
def _online_only(self):
return self.value == 1 return self.value == 1
@fill_with_flags()
class ApplicationFlags(BaseFlags):
r"""Wraps up the Discord Application flags.
.. container:: operations
.. describe:: x == y
Checks if two ApplicationFlags are equal.
.. describe:: x != y
Checks if two ApplicationFlags are not equal.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Note that aliases are not shown.
.. versionadded:: 2.0
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
@flag_value
def gateway_presence(self):
""":class:`bool`: Returns ``True`` if the application is verified and is allowed to
receive presence information over the gateway.
"""
return 1 << 12
@flag_value
def gateway_presence_limited(self):
""":class:`bool`: Returns ``True`` if the application is allowed to receive limited
presence information over the gateway.
"""
return 1 << 13
@flag_value
def gateway_guild_members(self):
""":class:`bool`: Returns ``True`` if the application is verified and is allowed to
receive guild members information over the gateway.
"""
return 1 << 14
@flag_value
def gateway_guild_members_limited(self):
""":class:`bool`: Returns ``True`` if the application is allowed to receive limited
guild members information over the gateway.
"""
return 1 << 15
@flag_value
def verification_pending_guild_limit(self):
""":class:`bool`: Returns ``True`` if the application is currently pending verification
and has hit the guild limit.
"""
return 1 << 16
@flag_value
def embedded(self):
""":class:`bool`: Returns ``True`` if the application is embedded within the Discord client."""
return 1 << 17

View File

@@ -25,6 +25,7 @@ DEALINGS IN THE SOFTWARE.
import asyncio import asyncio
from collections import namedtuple, deque from collections import namedtuple, deque
import concurrent.futures import concurrent.futures
import json
import logging import logging
import struct import struct
import sys import sys
@@ -40,7 +41,7 @@ from .activity import BaseActivity
from .enums import SpeakingState from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument from .errors import ConnectionClosed, InvalidArgument
_log = logging.getLogger(__name__) log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'DiscordWebSocket', 'DiscordWebSocket',
@@ -101,7 +102,7 @@ class GatewayRatelimiter:
async with self.lock: async with self.lock:
delta = self.get_delay() delta = self.get_delay()
if delta: if delta:
_log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta) log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
await asyncio.sleep(delta) await asyncio.sleep(delta)
@@ -129,20 +130,20 @@ class KeepAliveHandler(threading.Thread):
def run(self): def run(self):
while not self._stop_ev.wait(self.interval): while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter(): if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
coro = self.ws.close(4000) coro = self.ws.close(4000)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try: try:
f.result() f.result()
except Exception: except Exception:
_log.exception('An error occurred while stopping the gateway. Ignoring.') log.exception('An error occurred while stopping the gateway. Ignoring.')
finally: finally:
self.stop() self.stop()
return return
data = self.get_payload() data = self.get_payload()
_log.debug(self.msg, self.shard_id, data['d']) log.debug(self.msg, self.shard_id, data['d'])
coro = self.ws.send_heartbeat(data) coro = self.ws.send_heartbeat(data)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try: try:
@@ -161,7 +162,7 @@ class KeepAliveHandler(threading.Thread):
else: else:
stack = ''.join(traceback.format_stack(frame)) stack = ''.join(traceback.format_stack(frame))
msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}' msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
_log.warning(msg, self.shard_id, total) log.warning(msg, self.shard_id, total)
except Exception: except Exception:
self.stop() self.stop()
@@ -185,7 +186,7 @@ class KeepAliveHandler(threading.Thread):
self._last_ack = ack_time self._last_ack = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
if self.latency > 10: if self.latency > 10:
_log.warning(self.behind_msg, self.shard_id, self.latency) log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler): class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -293,12 +294,6 @@ class DiscordWebSocket:
def is_ratelimited(self): def is_ratelimited(self):
return self._rate_limiter.is_ratelimited() return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /):
self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /):
pass
@classmethod @classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
@@ -324,13 +319,9 @@ class DiscordWebSocket:
ws.sequence = sequence ws.sequence = sequence
ws._max_heartbeat_timeout = client._connection.heartbeat_timeout ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
if client._enable_debug_events:
ws.send = ws.debug_send
ws.log_receive = ws.debug_log_receive
client._connection._update_references(ws) client._connection._update_references(ws)
_log.debug('Created websocket connected to %s', gateway) log.debug('Created websocket connected to %s', gateway)
# poll event for OP Hello # poll event for OP Hello
await ws.poll_event() await ws.poll_event()
@@ -382,6 +373,7 @@ class DiscordWebSocket:
}, },
'compress': True, 'compress': True,
'large_threshold': 250, 'large_threshold': 250,
'guild_subscriptions': self._connection.guild_subscriptions,
'v': 3 'v': 3
} }
} }
@@ -403,7 +395,7 @@ class DiscordWebSocket:
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self): async def resume(self):
"""Sends the RESUME packet.""" """Sends the RESUME packet."""
@@ -417,9 +409,11 @@ class DiscordWebSocket:
} }
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg):
self._dispatch('socket_raw_receive', msg)
async def received_message(self, msg, /):
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
@@ -428,14 +422,10 @@ class DiscordWebSocket:
msg = self._zlib.decompress(self._buffer) msg = self._zlib.decompress(self._buffer)
msg = msg.decode('utf-8') msg = msg.decode('utf-8')
self._buffer = bytearray() self._buffer = bytearray()
msg = json.loads(msg)
self.log_receive(msg) log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
msg = utils._from_json(msg) self._dispatch('socket_response', msg)
_log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
event = msg.get('t')
if event:
self._dispatch('socket_event_type', event)
op = msg.get('op') op = msg.get('op')
data = msg.get('d') data = msg.get('d')
@@ -451,7 +441,7 @@ class DiscordWebSocket:
# "reconnect" can only be handled by the Client # "reconnect" can only be handled by the Client
# so we terminate our connection and raise an # so we terminate our connection and raise an
# internal exception signalling to reconnect. # internal exception signalling to reconnect.
_log.debug('Received RECONNECT opcode.') log.debug('Received RECONNECT opcode.')
await self.close() await self.close()
raise ReconnectWebSocket(self.shard_id) raise ReconnectWebSocket(self.shard_id)
@@ -481,33 +471,35 @@ class DiscordWebSocket:
self.sequence = None self.sequence = None
self.session_id = None self.session_id = None
_log.info('Shard ID %s session has been invalidated.', self.shard_id) log.info('Shard ID %s session has been invalidated.', self.shard_id)
await self.close(code=1000) await self.close(code=1000)
raise ReconnectWebSocket(self.shard_id, resume=False) raise ReconnectWebSocket(self.shard_id, resume=False)
_log.warning('Unknown OP code %s.', op) log.warning('Unknown OP code %s.', op)
return return
event = msg.get('t')
if event == 'READY': if event == 'READY':
self._trace = trace = data.get('_trace', []) self._trace = trace = data.get('_trace', [])
self.sequence = msg['s'] self.sequence = msg['s']
self.session_id = data['session_id'] self.session_id = data['session_id']
# pass back shard ID to ready handler # pass back shard ID to ready handler
data['__shard_id__'] = self.shard_id data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).', log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).',
self.shard_id, ', '.join(trace), self.session_id) self.shard_id, ', '.join(trace), self.session_id)
elif event == 'RESUMED': elif event == 'RESUMED':
self._trace = trace = data.get('_trace', []) self._trace = trace = data.get('_trace', [])
# pass back the shard ID to the resumed handler # pass back the shard ID to the resumed handler
data['__shard_id__'] = self.shard_id data['__shard_id__'] = self.shard_id
_log.info('Shard ID %s has successfully RESUMED session %s under trace %s.', log.info('Shard ID %s has successfully RESUMED session %s under trace %s.',
self.shard_id, self.session_id, ', '.join(trace)) self.shard_id, self.session_id, ', '.join(trace))
try: try:
func = self._discord_parsers[event] func = self._discord_parsers[event]
except KeyError: except KeyError:
_log.debug('Unknown event %s.', event) log.debug('Unknown event %s.', event)
else: else:
func(data) func(data)
@@ -561,10 +553,10 @@ class DiscordWebSocket:
elif msg.type is aiohttp.WSMsgType.BINARY: elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data) await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR: elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg) log.debug('Received %s', msg)
raise msg.data raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE): elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
_log.debug('Received %s', msg) log.debug('Received %s', msg)
raise WebSocketClosure raise WebSocketClosure
except (asyncio.TimeoutError, WebSocketClosure) as e: except (asyncio.TimeoutError, WebSocketClosure) as e:
# Ensure the keep alive handler is closed # Ensure the keep alive handler is closed
@@ -573,29 +565,25 @@ class DiscordWebSocket:
self._keep_alive = None self._keep_alive = None
if isinstance(e, asyncio.TimeoutError): if isinstance(e, asyncio.TimeoutError):
_log.info('Timed out receiving packet. Attempting a reconnect.') log.info('Timed out receiving packet. Attempting a reconnect.')
raise ReconnectWebSocket(self.shard_id) from None raise ReconnectWebSocket(self.shard_id) from None
code = self._close_code or self.socket.close_code code = self._close_code or self.socket.close_code
if self._can_handle_close(): if self._can_handle_close():
_log.info('Websocket closed with %s, attempting a reconnect.', code) log.info('Websocket closed with %s, attempting a reconnect.', code)
raise ReconnectWebSocket(self.shard_id) from None raise ReconnectWebSocket(self.shard_id) from None
else: else:
_log.info('Websocket closed with %s, cannot reconnect.', code) log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /): async def send(self, data):
await self._rate_limiter.block() await self._rate_limiter.block()
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await self.socket.send_str(data) await self.socket.send_str(data)
async def send(self, data, /):
await self._rate_limiter.block()
await self.socket.send_str(data)
async def send_as_json(self, data): async def send_as_json(self, data):
try: try:
await self.send(utils._to_json(data)) await self.send(utils.to_json(data))
except RuntimeError as exc: except RuntimeError as exc:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
@@ -603,18 +591,16 @@ class DiscordWebSocket:
async def send_heartbeat(self, data): async def send_heartbeat(self, data):
# This bypasses the rate limit handling code since it has a higher priority # This bypasses the rate limit handling code since it has a higher priority
try: try:
await self.socket.send_str(utils._to_json(data)) await self.socket.send_str(utils.to_json(data))
except RuntimeError as exc: except RuntimeError as exc:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0): async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
if activity is not None: if activity is not None:
if not isinstance(activity, BaseActivity): if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.') raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()] activity = activity.to_dict()
else:
activity = []
if status == 'idle': if status == 'idle':
since = int(time.time() * 1000) since = int(time.time() * 1000)
@@ -622,15 +608,15 @@ class DiscordWebSocket:
payload = { payload = {
'op': self.PRESENCE, 'op': self.PRESENCE,
'd': { 'd': {
'activities': activity, 'game': activity,
'afk': False, 'afk': afk,
'since': since, 'since': since,
'status': status 'status': status
} }
} }
sent = utils._to_json(payload) sent = utils.to_json(payload)
_log.debug('Sending "%s" to change status', sent) log.debug('Sending "%s" to change status', sent)
await self.send(sent) await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
@@ -666,7 +652,7 @@ class DiscordWebSocket:
} }
} }
_log.debug('Updating our voice state to %s.', payload) log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def close(self, code=4000): async def close(self, code=4000):
@@ -721,21 +707,16 @@ class DiscordVoiceWebSocket:
CLIENT_CONNECT = 12 CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None): def __init__(self, socket, loop):
self.ws = socket self.ws = socket
self.loop = loop self.loop = loop
self._keep_alive = None self._keep_alive = None
self._close_code = None self._close_code = None
self.secret_key = None self.secret_key = None
if hook:
self._hook = hook
async def _hook(self, *args):
pass
async def send_as_json(self, data): async def send_as_json(self, data):
_log.debug('Sending voice websocket frame: %s.', data) log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data)) await self.ws.send_str(utils.to_json(data))
send_heartbeat = send_as_json send_heartbeat = send_as_json
@@ -765,12 +746,12 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls, client, *, resume=False, hook=None): async def from_client(cls, client, *, resume=False):
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http
socket = await http.ws_connect(gateway, compress=15) socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook) ws = cls(socket, loop=client.loop)
ws.gateway = gateway ws.gateway = gateway
ws._connection = client ws._connection = client
ws._max_heartbeat_timeout = 60.0 ws._max_heartbeat_timeout = 60.0
@@ -820,7 +801,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def received_message(self, msg): async def received_message(self, msg):
_log.debug('Voice websocket frame received: %s', msg) log.debug('Voice websocket frame received: %s', msg)
op = msg['op'] op = msg['op']
data = msg.get('d') data = msg.get('d')
@@ -829,7 +810,7 @@ class DiscordVoiceWebSocket:
elif op == self.HEARTBEAT_ACK: elif op == self.HEARTBEAT_ACK:
self._keep_alive.ack() self._keep_alive.ack()
elif op == self.RESUMED: elif op == self.RESUMED:
_log.info('Voice RESUME succeeded.') log.info('Voice RESUME succeeded.')
elif op == self.SESSION_DESCRIPTION: elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode'] self._connection.mode = data['mode']
await self.load_secret_key(data) await self.load_secret_key(data)
@@ -838,8 +819,6 @@ class DiscordVoiceWebSocket:
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start() self._keep_alive.start()
await self._hook(self, msg)
async def initial_connection(self, data): async def initial_connection(self, data):
state = self._connection state = self._connection
state.ssrc = data['ssrc'] state.ssrc = data['ssrc']
@@ -852,7 +831,7 @@ class DiscordVoiceWebSocket:
struct.pack_into('>I', packet, 4, state.ssrc) struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 70) recv = await self.loop.sock_recv(state.socket, 70)
_log.debug('received packet in initial_connection: %s', recv) log.debug('received packet in initial_connection: %s', recv)
# the ip is ascii starting at the 4th byte and ending at the first null # the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4 ip_start = 4
@@ -860,15 +839,15 @@ class DiscordVoiceWebSocket:
state.ip = recv[ip_start:ip_end].decode('ascii') state.ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port) log.debug('detected ip: %s port: %s', state.ip, state.port)
# there *should* always be at least one supported mode (xsalsa20_poly1305) # there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes)) log.debug('received supported encryption modes: %s', ", ".join(modes))
mode = modes[0] mode = modes[0]
await self.select_protocol(state.ip, state.port, mode) await self.select_protocol(state.ip, state.port, mode)
_log.info('selected the voice protocol for use (%s)', mode) log.info('selected the voice protocol for use (%s)', mode)
@property @property
def latency(self): def latency(self):
@@ -886,7 +865,7 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data): async def load_secret_key(self, data):
_log.info('received secret key for voice connection') log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key') self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak() await self.speak()
await self.speak(False) await self.speak(False)
@@ -895,12 +874,12 @@ class DiscordVoiceWebSocket:
# This exception is handled up the chain # This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT: if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data)) await self.received_message(json.loads(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR: elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg) log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg) log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000): async def close(self, code=1000):

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -22,36 +22,17 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import datetime import datetime
from typing import Any, Dict, Optional, TYPE_CHECKING, overload, Type, Tuple from .utils import _get_as_snowflake, get, parse_time
from .utils import _get_as_snowflake, parse_time, MISSING
from .user import User from .user import User
from .errors import InvalidArgument from .errors import InvalidArgument
from .enums import try_enum, ExpireBehaviour from .enums import try_enum, ExpireBehaviour
__all__ = ( __all__ = (
'IntegrationAccount', 'IntegrationAccount',
'IntegrationApplication',
'Integration', 'Integration',
'StreamIntegration',
'BotIntegration',
) )
if TYPE_CHECKING:
from .types.integration import (
IntegrationAccount as IntegrationAccountPayload,
Integration as IntegrationPayload,
StreamIntegration as StreamIntegrationPayload,
BotIntegration as BotIntegrationPayload,
IntegrationType,
IntegrationApplication as IntegrationApplicationPayload,
)
from .guild import Guild
from .role import Role
class IntegrationAccount: class IntegrationAccount:
"""Represents an integration account. """Represents an integration account.
@@ -59,7 +40,7 @@ class IntegrationAccount:
Attributes Attributes
----------- -----------
id: :class:`str` id: :class:`int`
The account ID. The account ID.
name: :class:`str` name: :class:`str`
The account name. The account name.
@@ -67,13 +48,12 @@ class IntegrationAccount:
__slots__ = ('id', 'name') __slots__ = ('id', 'name')
def __init__(self, data: IntegrationAccountPayload) -> None: def __init__(self, **kwargs):
self.id: str = data['id'] self.id = kwargs.pop('id')
self.name: str = data['name'] self.name = kwargs.pop('name')
def __repr__(self) -> str:
return f'<IntegrationAccount id={self.id} name={self.name!r}>'
def __repr__(self):
return '<IntegrationAccount id={0.id} name={0.name!r}>'.format(self)
class Integration: class Integration:
"""Represents a guild integration. """Represents a guild integration.
@@ -82,83 +62,6 @@ class Integration:
Attributes Attributes
----------- -----------
id: :class:`int`
The integration ID.
name: :class:`str`
The integration name.
guild: :class:`Guild`
The guild of the integration.
type: :class:`str`
The integration type (i.e. Twitch).
enabled: :class:`bool`
Whether the integration is currently enabled.
account: :class:`IntegrationAccount`
The account linked to this integration.
user: :class:`User`
The user that added this integration.
"""
__slots__ = (
'guild',
'id',
'_state',
'type',
'name',
'account',
'user',
'enabled',
)
def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None:
self.guild = guild
self._state = guild._state
self._from_data(data)
def __repr__(self):
return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>"
def _from_data(self, data: IntegrationPayload) -> None:
self.id: int = int(data['id'])
self.type: IntegrationType = data['type']
self.name: str = data['name']
self.account: IntegrationAccount = IntegrationAccount(data['account'])
user = data.get('user')
self.user = User(state=self._state, data=user) if user else None
self.enabled: bool = data['enabled']
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the integration.
You must have the :attr:`~Permissions.manage_guild` permission to
do this.
Parameters
-----------
reason: :class:`str`
The reason the integration was deleted. Shows up on the audit log.
.. versionadded:: 2.0
Raises
-------
Forbidden
You do not have permission to delete the integration.
HTTPException
Deleting the integration failed.
"""
await self._state.http.delete_integration(self.guild.id, self.id, reason=reason)
class StreamIntegration(Integration):
"""Represents a stream integration for Twitch or YouTube.
.. versionadded:: 2.0
Attributes
----------
id: :class:`int` id: :class:`int`
The integration ID. The integration ID.
name: :class:`str` name: :class:`str`
@@ -171,6 +74,8 @@ class StreamIntegration(Integration):
Whether the integration is currently enabled. Whether the integration is currently enabled.
syncing: :class:`bool` syncing: :class:`bool`
Where the integration is currently syncing. Where the integration is currently syncing.
role: :class:`Role`
The role which the integration uses for subscribers.
enable_emoticons: Optional[:class:`bool`] enable_emoticons: Optional[:class:`bool`]
Whether emoticons should be synced for this integration (currently twitch only). Whether emoticons should be synced for this integration (currently twitch only).
expire_behaviour: :class:`ExpireBehaviour` expire_behaviour: :class:`ExpireBehaviour`
@@ -185,45 +90,37 @@ class StreamIntegration(Integration):
An aware UTC datetime representing when the integration was last synced. An aware UTC datetime representing when the integration was last synced.
""" """
__slots__ = ( __slots__ = ('id', '_state', 'guild', 'name', 'enabled', 'type',
'revoked', 'syncing', 'role', 'expire_behaviour', 'expire_behavior',
'expire_behaviour', 'expire_grace_period', 'synced_at', 'user', 'account',
'expire_grace_period', 'enable_emoticons', '_role_id')
'synced_at',
'_role_id',
'syncing',
'enable_emoticons',
'subscriber_count',
)
def _from_data(self, data: StreamIntegrationPayload) -> None: def __init__(self, *, data, guild):
super()._from_data(data) self.guild = guild
self.revoked: bool = data['revoked'] self._state = guild._state
self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior']) self._from_data(data)
self.expire_grace_period: int = data['expire_grace_period']
self.synced_at: datetime.datetime = parse_time(data['synced_at'])
self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id')
self.syncing: bool = data['syncing']
self.enable_emoticons: bool = data['enable_emoticons']
self.subscriber_count: int = data['subscriber_count']
@property def __repr__(self):
def expire_behavior(self) -> ExpireBehaviour: return '<Integration id={0.id} name={0.name!r} type={0.type!r}>'.format(self)
""":class:`ExpireBehaviour`: An alias for :attr:`expire_behaviour`."""
return self.expire_behaviour
@property def _from_data(self, integ):
def role(self) -> Optional[Role]: self.id = _get_as_snowflake(integ, 'id')
"""Optional[:class:`Role`] The role which the integration uses for subscribers.""" self.name = integ['name']
return self.guild.get_role(self._role_id) # type: ignore self.type = integ['type']
self.enabled = integ['enabled']
self.syncing = integ['syncing']
self._role_id = _get_as_snowflake(integ, 'role_id')
self.role = get(self.guild.roles, id=self._role_id)
self.enable_emoticons = integ.get('enable_emoticons')
self.expire_behaviour = try_enum(ExpireBehaviour, integ['expire_behavior'])
self.expire_behavior = self.expire_behaviour
self.expire_grace_period = integ['expire_grace_period']
self.synced_at = parse_time(integ['synced_at'])
async def edit( self.user = User(state=self._state, data=integ['user'])
self, self.account = IntegrationAccount(**integ['account'])
*,
expire_behaviour: ExpireBehaviour = MISSING, async def edit(self, **fields):
expire_grace_period: int = MISSING,
enable_emoticons: bool = MISSING,
) -> None:
"""|coro| """|coro|
Edits the integration. Edits the integration.
@@ -249,24 +146,34 @@ class StreamIntegration(Integration):
InvalidArgument InvalidArgument
``expire_behaviour`` did not receive a :class:`ExpireBehaviour`. ``expire_behaviour`` did not receive a :class:`ExpireBehaviour`.
""" """
payload: Dict[str, Any] = {} try:
if expire_behaviour is not MISSING: expire_behaviour = fields['expire_behaviour']
if not isinstance(expire_behaviour, ExpireBehaviour): except KeyError:
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour') expire_behaviour = fields.get('expire_behavior', self.expire_behaviour)
payload['expire_behavior'] = expire_behaviour.value if not isinstance(expire_behaviour, ExpireBehaviour):
raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour')
if expire_grace_period is not MISSING: expire_grace_period = fields.get('expire_grace_period', self.expire_grace_period)
payload['expire_grace_period'] = expire_grace_period
if enable_emoticons is not MISSING: payload = {
'expire_behavior': expire_behaviour.value,
'expire_grace_period': expire_grace_period,
}
enable_emoticons = fields.get('enable_emoticons')
if enable_emoticons is not None:
payload['enable_emoticons'] = enable_emoticons payload['enable_emoticons'] = enable_emoticons
# This endpoint is undocumented.
# Unsure if it returns the data or not as a result
await self._state.http.edit_integration(self.guild.id, self.id, **payload) await self._state.http.edit_integration(self.guild.id, self.id, **payload)
async def sync(self) -> None: self.expire_behaviour = expire_behaviour
self.expire_behavior = self.expire_behaviour
self.expire_grace_period = expire_grace_period
self.enable_emoticons = enable_emoticons
async def sync(self):
"""|coro| """|coro|
Syncs the integration. Syncs the integration.
@@ -284,83 +191,19 @@ class StreamIntegration(Integration):
await self._state.http.sync_integration(self.guild.id, self.id) await self._state.http.sync_integration(self.guild.id, self.id)
self.synced_at = datetime.datetime.now(datetime.timezone.utc) self.synced_at = datetime.datetime.now(datetime.timezone.utc)
async def delete(self):
"""|coro|
class IntegrationApplication: Deletes the integration.
"""Represents an application for a bot integration.
.. versionadded:: 2.0 You must have the :attr:`~Permissions.manage_guild` permission to
do this.
Attributes Raises
---------- -------
id: :class:`int` Forbidden
The ID for this application. You do not have permission to delete the integration.
name: :class:`str` HTTPException
The application's name. Deleting the integration failed.
icon: Optional[:class:`str`] """
The application's icon hash. await self._state.http.delete_integration(self.guild.id, self.id)
description: :class:`str`
The application's description. Can be an empty string.
summary: :class:`str`
The summary of the application. Can be an empty string.
user: Optional[:class:`User`]
The bot user on this application.
"""
__slots__ = (
'id',
'name',
'icon',
'description',
'summary',
'user',
)
def __init__(self, *, data: IntegrationApplicationPayload, state):
self.id: int = int(data['id'])
self.name: str = data['name']
self.icon: Optional[str] = data['icon']
self.description: str = data['description']
self.summary: str = data['summary']
user = data.get('bot')
self.user: Optional[User] = User(state=state, data=user) if user else None
class BotIntegration(Integration):
"""Represents a bot integration on discord.
.. versionadded:: 2.0
Attributes
----------
id: :class:`int`
The integration ID.
name: :class:`str`
The integration name.
guild: :class:`Guild`
The guild of the integration.
type: :class:`str`
The integration type (i.e. Twitch).
enabled: :class:`bool`
Whether the integration is currently enabled.
user: :class:`User`
The user that added this integration.
account: :class:`IntegrationAccount`
The integration account information.
application: :class:`IntegrationApplication`
The application tied to this integration.
"""
__slots__ = ('application',)
def _from_data(self, data: BotIntegrationPayload) -> None:
super()._from_data(data)
self.application = IntegrationApplication(data=data['application'], state=self._state)
def _integration_factory(value: str) -> Tuple[Type[Integration], str]:
if value == 'discord':
return BotIntegration, value
elif value in ('twitch', 'youtube'):
return StreamIntegration, value
else:
return Integration, value

View File

@@ -25,54 +25,20 @@ DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union
import asyncio
from . import utils from . import utils
from .enums import try_enum, InteractionType, InteractionResponseType from .enums import try_enum, InteractionType
from .errors import InteractionResponded, HTTPException, ClientException
from .channel import PartialMessageable, ChannelType
from .user import User
from .member import Member
from .message import Message, Attachment
from .object import Object
from .permissions import Permissions
from .webhook.async_ import async_context, Webhook, handle_message_parameters
__all__ = ( __all__ = (
'Interaction', 'Interaction',
'InteractionMessage',
'InteractionResponse',
) )
if TYPE_CHECKING:
from .types.interactions import (
Interaction as InteractionPayload,
InteractionData,
)
from .guild import Guild
from .state import ConnectionState
from .file import File
from .mentions import AllowedMentions
from aiohttp import ClientSession
from .embeds import Embed
from .ui.view import View
from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable
from .threads import Thread
InteractionChannel = Union[
VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable
]
MISSING: Any = utils.MISSING
class Interaction: class Interaction:
"""Represents a Discord interaction. """Represents a Discord interaction.
An interaction happens when a user does an action that needs to An interaction happens when a user does an action that needs to
be notified. Current examples are slash commands and components. be notified. Current examples are slash commands but future examples
include forms and buttons.
.. versionadded:: 2.0 .. versionadded:: 2.0
@@ -90,678 +56,49 @@ class Interaction:
The application ID that the interaction was for. The application ID that the interaction was for.
user: Optional[Union[:class:`User`, :class:`Member`]] user: Optional[Union[:class:`User`, :class:`Member`]]
The user or member that sent the interaction. The user or member that sent the interaction.
message: Optional[:class:`Message`]
The message that sent this interaction.
token: :class:`str` token: :class:`str`
The token to continue the interaction. These are valid The token to continue the interaction. These are valid
for 15 minutes. for 15 minutes.
data: :class:`dict`
The raw interaction data.
""" """
__slots__ = (
__slots__: Tuple[str, ...] = (
'id', 'id',
'type', 'type',
'guild_id', 'guild_id',
'channel_id', 'channel_id',
'data', 'data',
'application_id', 'application_id',
'message',
'user', 'user',
'token', 'token',
'version', 'version',
'_permissions',
'_state', '_state',
'_session',
'_original_message',
'_cs_response',
'_cs_followup',
'_cs_channel',
) )
def __init__(self, *, data: InteractionPayload, state: ConnectionState): def __init__(self, *, data, state=None):
self._state: ConnectionState = state self._state = state
self._session: ClientSession = state.http._HTTPClient__session
self._original_message: Optional[InteractionMessage] = None
self._from_data(data) self._from_data(data)
def _from_data(self, data: InteractionPayload): def _from_data(self, data):
self.id: int = int(data['id']) self.id = int(data['id'])
self.type: InteractionType = try_enum(InteractionType, data['type']) self.type = try_enum(InteractionType, data['type'])
self.data: Optional[InteractionData] = data.get('data') self.data = data.get('data')
self.token: str = data['token'] self.token = data['token']
self.version: int = data['version'] self.version = data['version']
self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') self.channel_id = utils._get_as_snowflake(data, 'channel_id')
self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') self.guild_id = utils._get_as_snowflake(data, 'guild_id')
self.application_id: int = int(data['application_id']) self.application_id = utils._get_as_snowflake(data, 'application_id')
self.message: Optional[Message]
try:
self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore
except KeyError:
self.message = None
self.user: Optional[Union[User, Member]] = None
self._permissions: int = 0
# TODO: there's a potential data loss here
if self.guild_id:
guild = self.guild or Object(id=self.guild_id)
try:
member = data['member'] # type: ignore
except KeyError:
pass
else:
self.user = Member(state=self._state, guild=guild, data=member) # type: ignore
self._permissions = int(member.get('permissions', 0))
else:
try:
self.user = User(state=self._state, data=data['user'])
except KeyError:
pass
@property @property
def guild(self) -> Optional[Guild]: def guild(self):
"""Optional[:class:`Guild`]: The guild the interaction was sent from.""" """Optional[:class:`Guild`]: The guild the interaction was sent from."""
return self._state and self._state._get_guild(self.guild_id) return self._state and self._state.get_guild(self.guild_id)
@utils.cached_slot_property('_cs_channel') @property
def channel(self) -> Optional[InteractionChannel]: def channel(self):
"""Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from. """Optional[:class:`abc.GuildChannel`]: The channel the interaction was sent from.
Note that due to a Discord limitation, DM channels are not resolved since there is Note that due to a Discord limitation, DM channels are not resolved since there is
no data to complete them. These are :class:`PartialMessageable` instead. no data to complete them.
""" """
guild = self.guild guild = self.guild
channel = guild and guild._resolve_channel(self.channel_id) return guild and guild.get_channel(self.channel_id)
if channel is None:
if self.channel_id is not None:
type = ChannelType.text if self.guild_id is not None else ChannelType.private
return PartialMessageable(state=self._state, id=self.channel_id, type=type)
return None
return channel
@property
def permissions(self) -> Permissions:
""":class:`Permissions`: The resolved permissions of the member in the channel, including overwrites.
In a non-guild context where this doesn't apply, an empty permissions object is returned.
"""
return Permissions(self._permissions)
@utils.cached_slot_property('_cs_response')
def response(self) -> InteractionResponse:
""":class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction.
A response can only be done once. If secondary messages need to be sent, consider using :attr:`followup`
instead.
"""
return InteractionResponse(self)
@utils.cached_slot_property('_cs_followup')
def followup(self) -> Webhook:
""":class:`Webhook`: Returns the follow up webhook for follow up interactions."""
payload = {
'id': self.application_id,
'type': 3,
'token': self.token,
}
return Webhook.from_state(data=payload, state=self._state)
async def original_message(self) -> InteractionMessage:
"""|coro|
Fetches the original interaction response message associated with the interaction.
If the interaction response was :meth:`InteractionResponse.send_message` then this would
return the message that was sent using that response. Otherwise, this would return
the message that triggered the interaction.
Repeated calls to this will return a cached value.
Raises
-------
HTTPException
Fetching the original response message failed.
ClientException
The channel for the message could not be resolved.
Returns
--------
InteractionMessage
The original interaction response message.
"""
if self._original_message is not None:
return self._original_message
# TODO: fix later to not raise?
channel = self.channel
if channel is None:
raise ClientException('Channel for message could not be resolved')
adapter = async_context.get()
data = await adapter.get_original_interaction_response(
application_id=self.application_id,
token=self.token,
session=self._session,
)
state = _InteractionMessageState(self, self._state)
message = InteractionMessage(state=state, channel=channel, data=data) # type: ignore
self._original_message = message
return message
async def edit_original_message(
self,
*,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> InteractionMessage:
"""|coro|
Edits the original interaction response message.
This is a lower level interface to :meth:`InteractionMessage.edit` in case
you do not want to fetch the message and save an HTTP request.
This method is also the only way to edit the original message if
the message sent was ephemeral.
Parameters
------------
content: Optional[:class:`str`]
The content to edit the message with or ``None`` to clear it.
embeds: List[:class:`Embed`]
A list of embeds to edit the message with.
embed: Optional[:class:`Embed`]
The embed to edit the message with. ``None`` suppresses the embeds.
This should not be mixed with the ``embeds`` parameter.
file: :class:`File`
The file to upload. This cannot be mixed with ``files`` parameter.
files: List[:class:`File`]
A list of files to send with the content. This cannot be mixed with the
``file`` parameter.
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
Raises
-------
HTTPException
Editing the message failed.
Forbidden
Edited a message that is not yours.
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid.
Returns
--------
:class:`InteractionMessage`
The newly edited message.
"""
previous_mentions: Optional[AllowedMentions] = self._state.allowed_mentions
params = handle_message_parameters(
content=content,
file=file,
files=files,
embed=embed,
embeds=embeds,
view=view,
allowed_mentions=allowed_mentions,
previous_allowed_mentions=previous_mentions,
)
adapter = async_context.get()
data = await adapter.edit_original_interaction_response(
self.application_id,
self.token,
session=self._session,
payload=params.payload,
multipart=params.multipart,
files=params.files,
)
# The message channel types should always match
message = InteractionMessage(state=self._state, channel=self.channel, data=data) # type: ignore
if view and not view.is_finished():
self._state.store_view(view, message.id)
return message
async def delete_original_message(self) -> None:
"""|coro|
Deletes the original interaction response message.
This is a lower level interface to :meth:`InteractionMessage.delete` in case
you do not want to fetch the message and save an HTTP request.
Raises
-------
HTTPException
Deleting the message failed.
Forbidden
Deleted a message that is not yours.
"""
adapter = async_context.get()
await adapter.delete_original_interaction_response(
self.application_id,
self.token,
session=self._session,
)
class InteractionResponse:
"""Represents a Discord interaction response.
This type can be accessed through :attr:`Interaction.response`.
.. versionadded:: 2.0
"""
__slots__: Tuple[str, ...] = (
'_responded',
'_parent',
)
def __init__(self, parent: Interaction):
self._parent: Interaction = parent
self._responded: bool = False
def is_done(self) -> bool:
""":class:`bool`: Indicates whether an interaction response has been done before.
An interaction can only be responded to once.
"""
return self._responded
async def defer(self, *, ephemeral: bool = False) -> None:
"""|coro|
Defers the interaction response.
This is typically used when the interaction is acknowledged
and a secondary action will be done later.
Parameters
-----------
ephemeral: :class:`bool`
Indicates whether the deferred message will eventually be ephemeral.
This only applies for interactions of type :attr:`InteractionType.application_command`.
Raises
-------
HTTPException
Deferring the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
raise InteractionResponded(self._parent)
defer_type: int = 0
data: Optional[Dict[str, Any]] = None
parent = self._parent
if parent.type is InteractionType.component:
defer_type = InteractionResponseType.deferred_message_update.value
elif parent.type is InteractionType.application_command:
defer_type = InteractionResponseType.deferred_channel_message.value
if ephemeral:
data = {'flags': 64}
if defer_type:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=defer_type, data=data
)
self._responded = True
async def pong(self) -> None:
"""|coro|
Pongs the ping interaction.
This should rarely be used.
Raises
-------
HTTPException
Ponging the interaction failed.
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
raise InteractionResponded(self._parent)
parent = self._parent
if parent.type is InteractionType.ping:
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value
)
self._responded = True
async def send_message(
self,
content: Optional[Any] = None,
*,
embed: Embed = MISSING,
embeds: List[Embed] = MISSING,
view: View = MISSING,
tts: bool = False,
ephemeral: bool = False,
) -> None:
"""|coro|
Responds to this interaction by sending a message.
Parameters
-----------
content: Optional[:class:`str`]
The content of the message to send.
embeds: List[:class:`Embed`]
A list of embeds to send with the content. Maximum of 10. This cannot
be mixed with the ``embed`` parameter.
embed: :class:`Embed`
The rich embed for the content to send. This cannot be mixed with
``embeds`` parameter.
tts: :class:`bool`
Indicates if the message should be sent using text-to-speech.
view: :class:`discord.ui.View`
The view to send with the message.
ephemeral: :class:`bool`
Indicates if the message should only be visible to the user who started the interaction.
If a view is sent with an ephemeral message and it has no timeout set then the timeout
is set to 15 minutes.
Raises
-------
HTTPException
Sending the message failed.
TypeError
You specified both ``embed`` and ``embeds``.
ValueError
The length of ``embeds`` was invalid.
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
raise InteractionResponded(self._parent)
payload: Dict[str, Any] = {
'tts': tts,
}
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix embed and embeds keyword arguments')
if embed is not MISSING:
embeds = [embed]
if embeds:
if len(embeds) > 10:
raise ValueError('embeds cannot exceed maximum of 10 elements')
payload['embeds'] = [e.to_dict() for e in embeds]
if content is not None:
payload['content'] = str(content)
if ephemeral:
payload['flags'] = 64
if view is not MISSING:
payload['components'] = view.to_components()
parent = self._parent
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.channel_message.value,
data=payload,
)
if view is not MISSING:
if ephemeral and view.timeout is None:
view.timeout = 15 * 60.0
self._parent._state.store_view(view)
self._responded = True
async def edit_message(
self,
*,
content: Optional[Any] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
attachments: List[Attachment] = MISSING,
view: Optional[View] = MISSING,
) -> None:
"""|coro|
Responds to this interaction by editing the original message of
a component interaction.
Parameters
-----------
content: Optional[:class:`str`]
The new content to replace the message with. ``None`` removes the content.
embeds: List[:class:`Embed`]
A list of embeds to edit the message with.
embed: Optional[:class:`Embed`]
The embed to edit the message with. ``None`` suppresses the embeds.
This should not be mixed with the ``embeds`` parameter.
attachments: List[:class:`Attachment`]
A list of attachments to keep in the message. If ``[]`` is passed
then all attachments are removed.
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
Raises
-------
HTTPException
Editing the message failed.
TypeError
You specified both ``embed`` and ``embeds``.
InteractionResponded
This interaction has already been responded to before.
"""
if self._responded:
raise InteractionResponded(self._parent)
parent = self._parent
msg = parent.message
state = parent._state
message_id = msg.id if msg else None
if parent.type is not InteractionType.component:
return
payload = {}
if content is not MISSING:
if content is None:
payload['content'] = None
else:
payload['content'] = str(content)
if embed is not MISSING and embeds is not MISSING:
raise TypeError('cannot mix both embed and embeds keyword arguments')
if embed is not MISSING:
if embed is None:
embeds = []
else:
embeds = [embed]
if embeds is not MISSING:
payload['embeds'] = [e.to_dict() for e in embeds]
if attachments is not MISSING:
payload['attachments'] = [a.to_dict() for a in attachments]
if view is not MISSING:
state.prevent_view_updates_for(message_id)
if view is None:
payload['components'] = []
else:
payload['components'] = view.to_components()
adapter = async_context.get()
await adapter.create_interaction_response(
parent.id,
parent.token,
session=parent._session,
type=InteractionResponseType.message_update.value,
data=payload,
)
if view and not view.is_finished():
state.store_view(view, message_id)
self._responded = True
class _InteractionMessageState:
__slots__ = ('_parent', '_interaction')
def __init__(self, interaction: Interaction, parent: ConnectionState):
self._interaction: Interaction = interaction
self._parent: ConnectionState = parent
def _get_guild(self, guild_id):
return self._parent._get_guild(guild_id)
def store_user(self, data):
return self._parent.store_user(data)
def create_user(self, data):
return self._parent.create_user(data)
@property
def http(self):
return self._parent.http
def __getattr__(self, attr):
return getattr(self._parent, attr)
class InteractionMessage(Message):
"""Represents the original interaction response message.
This allows you to edit or delete the message associated with
the interaction response. To retrieve this object see :meth:`Interaction.original_message`.
This inherits from :class:`discord.Message` with changes to
:meth:`edit` and :meth:`delete` to work.
.. versionadded:: 2.0
"""
__slots__ = ()
_state: _InteractionMessageState
async def edit(
self,
content: Optional[str] = MISSING,
embeds: List[Embed] = MISSING,
embed: Optional[Embed] = MISSING,
file: File = MISSING,
files: List[File] = MISSING,
view: Optional[View] = MISSING,
allowed_mentions: Optional[AllowedMentions] = None,
) -> InteractionMessage:
"""|coro|
Edits the message.
Parameters
------------
content: Optional[:class:`str`]
The content to edit the message with or ``None`` to clear it.
embeds: List[:class:`Embed`]
A list of embeds to edit the message with.
embed: Optional[:class:`Embed`]
The embed to edit the message with. ``None`` suppresses the embeds.
This should not be mixed with the ``embeds`` parameter.
file: :class:`File`
The file to upload. This cannot be mixed with ``files`` parameter.
files: List[:class:`File`]
A list of files to send with the content. This cannot be mixed with the
``file`` parameter.
allowed_mentions: :class:`AllowedMentions`
Controls the mentions being processed in this message.
See :meth:`.abc.Messageable.send` for more information.
view: Optional[:class:`~discord.ui.View`]
The updated view to update this message with. If ``None`` is passed then
the view is removed.
Raises
-------
HTTPException
Editing the message failed.
Forbidden
Edited a message that is not yours.
TypeError
You specified both ``embed`` and ``embeds`` or ``file`` and ``files``
ValueError
The length of ``embeds`` was invalid.
Returns
---------
:class:`InteractionMessage`
The newly edited message.
"""
return await self._state._interaction.edit_original_message(
content=content,
embeds=embeds,
embed=embed,
file=file,
files=files,
view=view,
allowed_mentions=allowed_mentions,
)
async def delete(self, *, delay: Optional[float] = None) -> None:
"""|coro|
Deletes the message.
Parameters
-----------
delay: Optional[:class:`float`]
If provided, the number of seconds to wait before deleting the message.
The waiting is done in the background and deletion failures are ignored.
Raises
------
Forbidden
You do not have proper permissions to delete the message.
NotFound
The message was deleted already.
HTTPException
Deleting the message failed.
"""
if delay is not None:
async def inner_call(delay: float = delay):
await asyncio.sleep(delay)
try:
await self._state._interaction.delete_original_message()
except HTTPException:
pass
asyncio.create_task(inner_call())
else:
await self._state._interaction.delete_original_message()

View File

@@ -22,15 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import List, Optional, Type, TypeVar, Union, TYPE_CHECKING
from .asset import Asset from .asset import Asset
from .utils import parse_time, snowflake_time, _get_as_snowflake from .utils import parse_time, snowflake_time, _get_as_snowflake
from .object import Object from .object import Object
from .mixins import Hashable from .mixins import Hashable
from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum from .enums import ChannelType, VerificationLevel, try_enum
from .appinfo import PartialAppInfo
__all__ = ( __all__ = (
'PartialInviteChannel', 'PartialInviteChannel',
@@ -38,26 +34,6 @@ __all__ = (
'Invite', 'Invite',
) )
if TYPE_CHECKING:
from .types.invite import (
Invite as InvitePayload,
InviteGuild as InviteGuildPayload,
GatewayInvite as GatewayInvitePayload,
)
from .types.channel import (
PartialChannel as InviteChannelPayload,
)
from .state import ConnectionState
from .guild import Guild
from .abc import GuildChannel
from .user import User
InviteGuildType = Union[Guild, 'PartialInviteGuild', Object]
InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object]
import datetime
class PartialInviteChannel: class PartialInviteChannel:
"""Represents a "partial" invite channel. """Represents a "partial" invite channel.
@@ -94,28 +70,27 @@ class PartialInviteChannel:
__slots__ = ('id', 'name', 'type') __slots__ = ('id', 'name', 'type')
def __init__(self, data: InviteChannelPayload): def __init__(self, **kwargs):
self.id: int = int(data['id']) self.id = kwargs.pop('id')
self.name: str = data['name'] self.name = kwargs.pop('name')
self.type: ChannelType = try_enum(ChannelType, data['type']) self.type = kwargs.pop('type')
def __str__(self) -> str: def __str__(self):
return self.name return self.name
def __repr__(self) -> str: def __repr__(self):
return f'<PartialInviteChannel id={self.id} name={self.name} type={self.type!r}>' return '<PartialInviteChannel id={0.id} name={0.name} type={0.type!r}>'.format(self)
@property @property
def mention(self) -> str: def mention(self):
""":class:`str`: The string that allows you to mention the channel.""" """:class:`str`: The string that allows you to mention the channel."""
return f'<#{self.id}>' return f'<#{self.id}>'
@property @property
def created_at(self) -> datetime.datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the channel's creation time in UTC.""" """:class:`datetime.datetime`: Returns the channel's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
class PartialInviteGuild: class PartialInviteGuild:
"""Represents a "partial" invite guild. """Represents a "partial" invite guild.
@@ -150,61 +125,93 @@ class PartialInviteGuild:
The partial guild's verification level. The partial guild's verification level.
features: List[:class:`str`] features: List[:class:`str`]
A list of features the guild has. See :attr:`Guild.features` for more information. A list of features the guild has. See :attr:`Guild.features` for more information.
icon: Optional[:class:`str`]
The partial guild's icon.
banner: Optional[:class:`str`]
The partial guild's banner.
splash: Optional[:class:`str`]
The partial guild's invite splash.
description: Optional[:class:`str`] description: Optional[:class:`str`]
The partial guild's description. The partial guild's description.
""" """
__slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') __slots__ = ('_state', 'features', 'icon', 'banner', 'id', 'name', 'splash',
'verification_level', 'description')
def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): def __init__(self, state, data, id):
self._state: ConnectionState = state self._state = state
self.id: int = id self.id = id
self.name: str = data['name'] self.name = data['name']
self.features: List[str] = data.get('features', []) self.features = data.get('features', [])
self._icon: Optional[str] = data.get('icon') self.icon = data.get('icon')
self._banner: Optional[str] = data.get('banner') self.banner = data.get('banner')
self._splash: Optional[str] = data.get('splash') self.splash = data.get('splash')
self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) self.verification_level = try_enum(VerificationLevel, data.get('verification_level'))
self.description: Optional[str] = data.get('description') self.description = data.get('description')
def __str__(self) -> str: def __str__(self):
return self.name return self.name
def __repr__(self) -> str: def __repr__(self):
return ( return '<{0.__class__.__name__} id={0.id} name={0.name!r} features={0.features} ' \
f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} ' 'description={0.description!r}>'.format(self)
f'description={self.description!r}>'
)
@property @property
def created_at(self) -> datetime.datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the guild's creation time in UTC.""" """:class:`datetime.datetime`: Returns the guild's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
@property @property
def icon(self) -> Optional[Asset]: def icon_url(self):
"""Optional[:class:`Asset`]: Returns the guild's icon asset, if available.""" """:class:`Asset`: Returns the guild's icon asset."""
if self._icon is None: return self.icon_url_as()
return None
return Asset._from_guild_icon(self._state, self.id, self._icon) def is_icon_animated(self):
""":class:`bool`: Returns ``True`` if the guild has an animated icon.
.. versionadded:: 1.4
"""
return bool(self.icon and self.icon.startswith('a_'))
def icon_url_as(self, *, format=None, static_format='webp', size=1024):
"""The same operation as :meth:`Guild.icon_url_as`.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_guild_icon(self._state, self, format=format, static_format=static_format, size=size)
@property @property
def banner(self) -> Optional[Asset]: def banner_url(self):
"""Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" """:class:`Asset`: Returns the guild's banner asset."""
if self._banner is None: return self.banner_url_as()
return None
return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') def banner_url_as(self, *, format='webp', size=2048):
"""The same operation as :meth:`Guild.banner_url_as`.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_guild_image(self._state, self.id, self.banner, 'banners', format=format, size=size)
@property @property
def splash(self) -> Optional[Asset]: def splash_url(self):
"""Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" """:class:`Asset`: Returns the guild's invite splash asset."""
if self._splash is None: return self.splash_url_as()
return None
return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes')
def splash_url_as(self, *, format='webp', size=2048):
"""The same operation as :meth:`Guild.splash_url_as`.
I = TypeVar('I', bound='Invite') Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_guild_image(self._state, self.id, self.splash, 'splashes', format=format, size=size)
class Invite(Hashable): class Invite(Hashable):
r"""Represents a Discord :class:`Guild` or :class:`abc.GuildChannel` invite. r"""Represents a Discord :class:`Guild` or :class:`abc.GuildChannel` invite.
@@ -232,32 +239,30 @@ class Invite(Hashable):
The following table illustrates what methods will obtain the attributes: The following table illustrates what methods will obtain the attributes:
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| Attribute | Method | | Attribute | Method |
+====================================+============================================================+ +====================================+==========================================================+
| :attr:`max_age` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` | | :attr:`max_age` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`max_uses` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` | | :attr:`max_uses` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`created_at` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` | | :attr:`created_at` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`temporary` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` | | :attr:`temporary` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`uses` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` | | :attr:`uses` | :meth:`abc.GuildChannel.invites`\, :meth:`Guild.invites` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`approximate_member_count` | :meth:`Client.fetch_invite` with `with_counts` enabled | | :attr:`approximate_member_count` | :meth:`Client.fetch_invite` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`approximate_presence_count` | :meth:`Client.fetch_invite` with `with_counts` enabled | | :attr:`approximate_presence_count` | :meth:`Client.fetch_invite` |
+------------------------------------+------------------------------------------------------------+ +------------------------------------+----------------------------------------------------------+
| :attr:`expires_at` | :meth:`Client.fetch_invite` with `with_expiration` enabled |
+------------------------------------+------------------------------------------------------------+
If it's not in the table above then it is available by all methods. If it's not in the table above then it is available by all methods.
Attributes Attributes
----------- -----------
max_age: :class:`int` max_age: :class:`int`
How long before the invite expires in seconds. How long the before the invite expires in seconds.
A value of ``0`` indicates that it doesn't expire. A value of ``0`` indicates that it doesn't expire.
code: :class:`str` code: :class:`str`
The URL fragment used for the invite. The URL fragment used for the invite.
@@ -275,185 +280,105 @@ class Invite(Hashable):
max_uses: :class:`int` max_uses: :class:`int`
How many times the invite can be used. How many times the invite can be used.
A value of ``0`` indicates that it has unlimited uses. A value of ``0`` indicates that it has unlimited uses.
inviter: Optional[:class:`User`] inviter: :class:`User`
The user who created the invite. The user who created the invite.
approximate_member_count: Optional[:class:`int`] approximate_member_count: Optional[:class:`int`]
The approximate number of members in the guild. The approximate number of members in the guild.
approximate_presence_count: Optional[:class:`int`] approximate_presence_count: Optional[:class:`int`]
The approximate number of members currently active in the guild. The approximate number of members currently active in the guild.
This includes idle, dnd, online, and invisible members. Offline members are excluded. This includes idle, dnd, online, and invisible members. Offline members are excluded.
expires_at: Optional[:class:`datetime.datetime`]
The expiration date of the invite. If the value is ``None`` when received through
`Client.fetch_invite` with `with_expiration` enabled, the invite will never expire.
.. versionadded:: 2.0
channel: Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`] channel: Union[:class:`abc.GuildChannel`, :class:`Object`, :class:`PartialInviteChannel`]
The channel the invite is for. The channel the invite is for.
target_type: :class:`InviteTarget`
The type of target for the voice channel invite.
.. versionadded:: 2.0
target_user: Optional[:class:`User`]
The user whose stream to display for this invite, if any.
.. versionadded:: 2.0
target_application: Optional[:class:`PartialAppInfo`]
The embedded application the invite targets, if any.
.. versionadded:: 2.0
""" """
__slots__ = ( __slots__ = ('max_age', 'code', 'guild', 'revoked', 'created_at', 'uses',
'max_age', 'temporary', 'max_uses', 'inviter', 'channel', '_state',
'code', 'approximate_member_count', 'approximate_presence_count' )
'guild',
'revoked',
'created_at',
'uses',
'temporary',
'max_uses',
'inviter',
'channel',
'target_user',
'target_type',
'_state',
'approximate_member_count',
'approximate_presence_count',
'target_application',
'expires_at',
)
BASE = 'https://discord.gg' BASE = 'https://discord.gg'
def __init__( def __init__(self, *, state, data):
self, self._state = state
*, self.max_age = data.get('max_age')
state: ConnectionState, self.code = data.get('code')
data: InvitePayload, self.guild = data.get('guild')
guild: Optional[Union[PartialInviteGuild, Guild]] = None, self.revoked = data.get('revoked')
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, self.created_at = parse_time(data.get('created_at'))
): self.temporary = data.get('temporary')
self._state: ConnectionState = state self.uses = data.get('uses')
self.max_age: Optional[int] = data.get('max_age') self.max_uses = data.get('max_uses')
self.code: str = data['code'] self.approximate_presence_count = data.get('approximate_presence_count')
self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) self.approximate_member_count = data.get('approximate_member_count')
self.revoked: Optional[bool] = data.get('revoked')
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at'))
self.temporary: Optional[bool] = data.get('temporary')
self.uses: Optional[int] = data.get('uses')
self.max_uses: Optional[int] = data.get('max_uses')
self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count')
self.approximate_member_count: Optional[int] = data.get('approximate_member_count')
expires_at = data.get('expires_at', None)
self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None
inviter_data = data.get('inviter') inviter_data = data.get('inviter')
self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) self.inviter = None if inviter_data is None else self._state.store_user(inviter_data)
self.channel = data.get('channel')
self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel)
target_user_data = data.get('target_user')
self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data)
self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0))
application = data.get('target_application')
self.target_application: Optional[PartialAppInfo] = (
PartialAppInfo(data=application, state=state) if application else None
)
@classmethod @classmethod
def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: def from_incomplete(cls, *, state, data):
guild: Optional[Union[Guild, PartialInviteGuild]]
try: try:
guild_data = data['guild'] guild_id = int(data['guild']['id'])
except KeyError: except KeyError:
# If we're here, then this is a group DM # If we're here, then this is a group DM
guild = None guild = None
else: else:
guild_id = int(guild_data['id'])
guild = state._get_guild(guild_id) guild = state._get_guild(guild_id)
if guild is None: if guild is None:
# If it's not cached, then it has to be a partial guild # If it's not cached, then it has to be a partial guild
guild_data = data['guild']
guild = PartialInviteGuild(state, guild_data, guild_id) guild = PartialInviteGuild(state, guild_data, guild_id)
# As far as I know, invites always need a channel # As far as I know, invites always need a channel
# So this should never raise. # So this should never raise.
channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) channel_data = data['channel']
channel_id = int(channel_data['id'])
channel_type = try_enum(ChannelType, channel_data['type'])
channel = PartialInviteChannel(id=channel_id, name=channel_data['name'], type=channel_type)
if guild is not None and not isinstance(guild, PartialInviteGuild): if guild is not None and not isinstance(guild, PartialInviteGuild):
# Upgrade the partial data if applicable # Upgrade the partial data if applicable
channel = guild.get_channel(channel.id) or channel channel = guild.get_channel(channel_id) or channel
return cls(state=state, data=data, guild=guild, channel=channel) data['guild'] = guild
data['channel'] = channel
return cls(state=state, data=data)
@classmethod @classmethod
def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: def from_gateway(cls, *, state, data):
guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') guild_id = _get_as_snowflake(data, 'guild_id')
guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) guild = state._get_guild(guild_id)
channel_id = int(data['channel_id']) channel_id = _get_as_snowflake(data, 'channel_id')
if guild is not None: if guild is not None:
channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore channel = guild.get_channel(channel_id) or Object(id=channel_id)
else: else:
guild = Object(id=guild_id) if guild_id is not None else None guild = Object(id=guild_id)
channel = Object(id=channel_id) channel = Object(id=channel_id)
return cls(state=state, data=data, guild=guild, channel=channel) # type: ignore data['guild'] = guild
data['channel'] = channel
return cls(state=state, data=data)
def _resolve_guild( def __str__(self):
self,
data: Optional[InviteGuildPayload],
guild: Optional[Union[Guild, PartialInviteGuild]] = None,
) -> Optional[InviteGuildType]:
if guild is not None:
return guild
if data is None:
return None
guild_id = int(data['id'])
return PartialInviteGuild(self._state, data, guild_id)
def _resolve_channel(
self,
data: Optional[InviteChannelPayload],
channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None,
) -> Optional[InviteChannelType]:
if channel is not None:
return channel
if data is None:
return None
return PartialInviteChannel(data)
def __str__(self) -> str:
return self.url return self.url
def __repr__(self) -> str: def __repr__(self):
return ( return '<Invite code={0.code!r} guild={0.guild!r} ' \
f'<Invite code={self.code!r} guild={self.guild!r} ' 'online={0.approximate_presence_count} ' \
f'online={self.approximate_presence_count} ' 'members={0.approximate_member_count}>'.format(self)
f'members={self.approximate_member_count}>'
)
def __hash__(self) -> int: def __hash__(self):
return hash(self.code) return hash(self.code)
@property @property
def id(self) -> str: def id(self):
""":class:`str`: Returns the proper code portion of the invite.""" """:class:`str`: Returns the proper code portion of the invite."""
return self.code return self.code
@property @property
def url(self) -> str: def url(self):
""":class:`str`: A property that retrieves the invite URL.""" """:class:`str`: A property that retrieves the invite URL."""
return self.BASE + '/' + self.code return self.BASE + '/' + self.code
async def delete(self, *, reason: Optional[str] = None): async def delete(self, *, reason=None):
"""|coro| """|coro|
Revokes the instant invite. Revokes the instant invite.

View File

@@ -26,10 +26,10 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator from typing import TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator, Coroutine
from .errors import NoMoreItems from .errors import NoMoreItems
from .utils import snowflake_time, time_snowflake, maybe_coroutine from .utils import time_snowflake, maybe_coroutine
from .object import Object from .object import Object
from .audit_logs import AuditLogEntry from .audit_logs import AuditLogEntry
@@ -42,46 +42,24 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.audit_log import (
AuditLog as AuditLogPayload,
)
from .types.guild import (
Guild as GuildPayload,
)
from .types.message import (
Message as MessagePayload,
)
from .types.user import (
PartialUser as PartialUserPayload,
)
from .types.threads import (
Thread as ThreadPayload,
)
from .member import Member from .member import Member
from .user import User from .user import User
from .message import Message from .message import Message
from .audit_logs import AuditLogEntry from .audit_logs import AuditLogEntry
from .guild import Guild from .guild import Guild
from .threads import Thread
from .abc import Snowflake
T = TypeVar('T') T = TypeVar('T')
OT = TypeVar('OT') OT = TypeVar('OT')
_Func = Callable[[T], Union[OT, Awaitable[OT]]] _Func = Callable[[T], Union[OT, Coroutine[Any, Any, OT]]]
_Predicate = Callable[[T], Union[T, Coroutine[Any, Any, T]]]
OLDEST_OBJECT = Object(id=0) OLDEST_OBJECT = Object(id=0)
class _AsyncIterator(AsyncIterator[T]): class _AsyncIterator(AsyncIterator[T]):
__slots__ = () __slots__ = ()
async def next(self) -> T: def get(self, **attrs: Any) -> Optional[T]:
raise NotImplementedError def predicate(elem):
def get(self, **attrs: Any) -> Awaitable[Optional[T]]:
def predicate(elem: T):
for attr, val in attrs.items(): for attr, val in attrs.items():
nested = attr.split('__') nested = attr.split('__')
obj = elem obj = elem
@@ -94,7 +72,7 @@ class _AsyncIterator(AsyncIterator[T]):
return self.find(predicate) return self.find(predicate)
async def find(self, predicate: _Func[T, bool]) -> Optional[T]: async def find(self, predicate: _Predicate[T]) -> Optional[T]:
while True: while True:
try: try:
elem = await self.next() elem = await self.next()
@@ -113,7 +91,7 @@ class _AsyncIterator(AsyncIterator[T]):
def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]:
return _MappedAsyncIterator(self, func) return _MappedAsyncIterator(self, func)
def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: def filter(self, predicate: _Predicate[T]) -> _FilteredAsyncIterator[T]:
return _FilteredAsyncIterator(self, predicate) return _FilteredAsyncIterator(self, predicate)
async def flatten(self) -> List[T]: async def flatten(self) -> List[T]:
@@ -125,18 +103,16 @@ class _AsyncIterator(AsyncIterator[T]):
except NoMoreItems: except NoMoreItems:
raise StopAsyncIteration() raise StopAsyncIteration()
def _identity(x): def _identity(x):
return x return x
class _ChunkedAsyncIterator(_AsyncIterator[T]):
class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
def __init__(self, iterator, max_size): def __init__(self, iterator, max_size):
self.iterator = iterator self.iterator = iterator
self.max_size = max_size self.max_size = max_size
async def next(self) -> List[T]: async def next(self) -> T:
ret: List[T] = [] ret = []
n = 0 n = 0
while n < self.max_size: while n < self.max_size:
try: try:
@@ -150,7 +126,6 @@ class _ChunkedAsyncIterator(_AsyncIterator[List[T]]):
n += 1 n += 1
return ret return ret
class _MappedAsyncIterator(_AsyncIterator[T]): class _MappedAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, func): def __init__(self, iterator, func):
self.iterator = iterator self.iterator = iterator
@@ -161,7 +136,6 @@ class _MappedAsyncIterator(_AsyncIterator[T]):
item = await self.iterator.next() item = await self.iterator.next()
return await maybe_coroutine(self.func, item) return await maybe_coroutine(self.func, item)
class _FilteredAsyncIterator(_AsyncIterator[T]): class _FilteredAsyncIterator(_AsyncIterator[T]):
def __init__(self, iterator, predicate): def __init__(self, iterator, predicate):
self.iterator = iterator self.iterator = iterator
@@ -181,7 +155,6 @@ class _FilteredAsyncIterator(_AsyncIterator[T]):
if ret: if ret:
return item return item
class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
def __init__(self, message, emoji, limit=100, after=None): def __init__(self, message, emoji, limit=100, after=None):
self.message = message self.message = message
@@ -195,7 +168,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
self.channel_id = message.channel.id self.channel_id = message.channel.id
self.users = asyncio.Queue() self.users = asyncio.Queue()
async def next(self) -> Union[User, Member]: async def next(self) -> T:
if self.users.empty(): if self.users.empty():
await self.fill_users() await self.fill_users()
@@ -212,9 +185,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
retrieve = self.limit if self.limit <= 100 else 100 retrieve = self.limit if self.limit <= 100 else 100
after = self.after.id if self.after else None after = self.after.id if self.after else None
data: List[PartialUserPayload] = await self.getter( data = await self.getter(self.channel_id, self.message.id, self.emoji, retrieve, after=after)
self.channel_id, self.message.id, self.emoji, retrieve, after=after
)
if data: if data:
self.limit -= retrieve self.limit -= retrieve
@@ -232,7 +203,6 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]):
else: else:
await self.users.put(User(state=self.state, data=element)) await self.users.put(User(state=self.state, data=element))
class HistoryIterator(_AsyncIterator['Message']): class HistoryIterator(_AsyncIterator['Message']):
"""Iterator for receiving a channel's message history. """Iterator for receiving a channel's message history.
@@ -267,7 +237,8 @@ class HistoryIterator(_AsyncIterator['Message']):
``True`` if `after` is specified, otherwise ``False``. ``True`` if `after` is specified, otherwise ``False``.
""" """
def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None): def __init__(self, messageable, limit,
before=None, after=None, around=None, oldest_first=None):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
before = Object(id=time_snowflake(before, high=False)) before = Object(id=time_snowflake(before, high=False))
@@ -301,7 +272,7 @@ class HistoryIterator(_AsyncIterator['Message']):
elif self.limit == 101: elif self.limit == 101:
self.limit = 100 # Thanks discord self.limit = 100 # Thanks discord
self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_around_strategy
if self.before and self.after: if self.before and self.after:
self._filter = lambda m: self.after.id < int(m['id']) < self.before.id self._filter = lambda m: self.after.id < int(m['id']) < self.before.id
elif self.before: elif self.before:
@@ -310,15 +281,15 @@ class HistoryIterator(_AsyncIterator['Message']):
self._filter = lambda m: self.after.id < int(m['id']) self._filter = lambda m: self.after.id < int(m['id'])
else: else:
if self.reverse: if self.reverse:
self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_after_strategy
if self.before: if (self.before):
self._filter = lambda m: int(m['id']) < self.before.id self._filter = lambda m: int(m['id']) < self.before.id
else: else:
self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore self._retrieve_messages = self._retrieve_messages_before_strategy
if self.after and self.after != OLDEST_OBJECT: if (self.after and self.after != OLDEST_OBJECT):
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m['id']) > self.after.id
async def next(self) -> Message: async def next(self) -> T:
if self.messages.empty(): if self.messages.empty():
await self.fill_messages() await self.fill_messages()
@@ -345,7 +316,7 @@ class HistoryIterator(_AsyncIterator['Message']):
if self._get_retrieve(): if self._get_retrieve():
data = await self._retrieve_messages(self.retrieve) data = await self._retrieve_messages(self.retrieve)
if len(data) < 100: if len(data) < 100:
self.limit = 0 # terminate the infinite loop self.limit = 0 # terminate the infinite loop
if self.reverse: if self.reverse:
data = reversed(data) data = reversed(data)
@@ -356,14 +327,14 @@ class HistoryIterator(_AsyncIterator['Message']):
for element in data: for element in data:
await self.messages.put(self.state.create_message(channel=channel, data=element)) await self.messages.put(self.state.create_message(channel=channel, data=element))
async def _retrieve_messages(self, retrieve) -> List[Message]: async def _retrieve_messages(self, retrieve):
"""Retrieve messages and update next parameters.""" """Retrieve messages and update next parameters."""
raise NotImplementedError pass
async def _retrieve_messages_before_strategy(self, retrieve): async def _retrieve_messages_before_strategy(self, retrieve):
"""Retrieve messages using before parameter.""" """Retrieve messages using before parameter."""
before = self.before.id if self.before else None before = self.before.id if self.before else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) data = await self.logs_from(self.channel.id, retrieve, before=before)
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
@@ -373,7 +344,7 @@ class HistoryIterator(_AsyncIterator['Message']):
async def _retrieve_messages_after_strategy(self, retrieve): async def _retrieve_messages_after_strategy(self, retrieve):
"""Retrieve messages using after parameter.""" """Retrieve messages using after parameter."""
after = self.after.id if self.after else None after = self.after.id if self.after else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) data = await self.logs_from(self.channel.id, retrieve, after=after)
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
@@ -384,12 +355,11 @@ class HistoryIterator(_AsyncIterator['Message']):
"""Retrieve messages using around parameter.""" """Retrieve messages using around parameter."""
if self.around: if self.around:
around = self.around.id if self.around else None around = self.around.id if self.around else None
data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) data = await self.logs_from(self.channel.id, retrieve, around=around)
self.around = None self.around = None
return data return data
return [] return []
class AuditLogIterator(_AsyncIterator['AuditLogEntry']): class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
@@ -397,6 +367,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True)) after = Object(id=time_snowflake(after, high=True))
if oldest_first is None: if oldest_first is None:
self.reverse = after is not None self.reverse = after is not None
else: else:
@@ -413,10 +384,12 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self._users = {} self._users = {}
self._state = guild._state self._state = guild._state
self._filter = None # entry dict -> bool self._filter = None # entry dict -> bool
self.entries = asyncio.Queue() self.entries = asyncio.Queue()
if self.reverse: if self.reverse:
self._strategy = self._after_strategy self._strategy = self._after_strategy
if self.before: if self.before:
@@ -428,9 +401,8 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
async def _before_strategy(self, retrieve): async def _before_strategy(self, retrieve):
before = self.before.id if self.before else None before = self.before.id if self.before else None
data: AuditLogPayload = await self.request( data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before action_type=self.action_type, before=before)
)
entries = data.get('audit_log_entries', []) entries = data.get('audit_log_entries', [])
if len(data) and entries: if len(data) and entries:
@@ -441,9 +413,8 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
async def _after_strategy(self, retrieve): async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None after = self.after.id if self.after else None
data: AuditLogPayload = await self.request( data = await self.request(self.guild.id, limit=retrieve, user_id=self.user_id,
self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after action_type=self.action_type, after=after)
)
entries = data.get('audit_log_entries', []) entries = data.get('audit_log_entries', [])
if len(data) and entries: if len(data) and entries:
if self.limit is not None: if self.limit is not None:
@@ -451,7 +422,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
self.after = Object(id=int(entries[0]['id'])) self.after = Object(id=int(entries[0]['id']))
return data.get('users', []), entries return data.get('users', []), entries
async def next(self) -> AuditLogEntry: async def next(self) -> T:
if self.entries.empty(): if self.entries.empty():
await self._fill() await self._fill()
@@ -475,7 +446,7 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']):
if self._get_retrieve(): if self._get_retrieve():
users, data = await self._strategy(self.retrieve) users, data = await self._strategy(self.retrieve)
if len(data) < 100: if len(data) < 100:
self.limit = 0 # terminate the infinite loop self.limit = 0 # terminate the infinite loop
if self.reverse: if self.reverse:
data = reversed(data) data = reversed(data)
@@ -522,7 +493,6 @@ class GuildIterator(_AsyncIterator['Guild']):
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]] after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Object after which all guilds must be. Object after which all guilds must be.
""" """
def __init__(self, bot, limit, before=None, after=None): def __init__(self, bot, limit, before=None, after=None):
if isinstance(before, datetime.datetime): if isinstance(before, datetime.datetime):
@@ -542,14 +512,14 @@ class GuildIterator(_AsyncIterator['Guild']):
self.guilds = asyncio.Queue() self.guilds = asyncio.Queue()
if self.before and self.after: if self.before and self.after:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore self._retrieve_guilds = self._retrieve_guilds_before_strategy
self._filter = lambda m: int(m['id']) > self.after.id self._filter = lambda m: int(m['id']) > self.after.id
elif self.after: elif self.after:
self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore self._retrieve_guilds = self._retrieve_guilds_after_strategy
else: else:
self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore self._retrieve_guilds = self._retrieve_guilds_before_strategy
async def next(self) -> Guild: async def next(self) -> T:
if self.guilds.empty(): if self.guilds.empty():
await self.fill_guilds() await self.fill_guilds()
@@ -569,7 +539,6 @@ class GuildIterator(_AsyncIterator['Guild']):
def create_guild(self, data): def create_guild(self, data):
from .guild import Guild from .guild import Guild
return Guild(state=self.state, data=data) return Guild(state=self.state, data=data)
async def fill_guilds(self): async def fill_guilds(self):
@@ -584,14 +553,14 @@ class GuildIterator(_AsyncIterator['Guild']):
for element in data: for element in data:
await self.guilds.put(self.create_guild(element)) await self.guilds.put(self.create_guild(element))
async def _retrieve_guilds(self, retrieve) -> List[Guild]: async def _retrieve_guilds(self, retrieve):
"""Retrieve guilds and update next parameters.""" """Retrieve guilds and update next parameters."""
raise NotImplementedError pass
async def _retrieve_guilds_before_strategy(self, retrieve): async def _retrieve_guilds_before_strategy(self, retrieve):
"""Retrieve guilds using before parameter.""" """Retrieve guilds using before parameter."""
before = self.before.id if self.before else None before = self.before.id if self.before else None
data: List[GuildPayload] = await self.get_guilds(retrieve, before=before) data = await self.get_guilds(retrieve, before=before)
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
@@ -601,14 +570,13 @@ class GuildIterator(_AsyncIterator['Guild']):
async def _retrieve_guilds_after_strategy(self, retrieve): async def _retrieve_guilds_after_strategy(self, retrieve):
"""Retrieve guilds using after parameter.""" """Retrieve guilds using after parameter."""
after = self.after.id if self.after else None after = self.after.id if self.after else None
data: List[GuildPayload] = await self.get_guilds(retrieve, after=after) data = await self.get_guilds(retrieve, after=after)
if len(data): if len(data):
if self.limit is not None: if self.limit is not None:
self.limit -= retrieve self.limit -= retrieve
self.after = Object(id=int(data[0]['id'])) self.after = Object(id=int(data[0]['id']))
return data return data
class MemberIterator(_AsyncIterator['Member']): class MemberIterator(_AsyncIterator['Member']):
def __init__(self, guild, limit=1000, after=None): def __init__(self, guild, limit=1000, after=None):
@@ -623,7 +591,7 @@ class MemberIterator(_AsyncIterator['Member']):
self.get_members = self.state.http.get_members self.get_members = self.state.http.get_members
self.members = asyncio.Queue() self.members = asyncio.Queue()
async def next(self) -> Member: async def next(self) -> T:
if self.members.empty(): if self.members.empty():
await self.fill_members() await self.fill_members()
@@ -650,7 +618,7 @@ class MemberIterator(_AsyncIterator['Member']):
return return
if len(data) < 1000: if len(data) < 1000:
self.limit = 0 # terminate loop self.limit = 0 # terminate loop
self.after = Object(id=int(data[-1]['user']['id'])) self.after = Object(id=int(data[-1]['user']['id']))
@@ -659,95 +627,4 @@ class MemberIterator(_AsyncIterator['Member']):
def create_member(self, data): def create_member(self, data):
from .member import Member from .member import Member
return Member(data=data, guild=self.guild, state=self.state) return Member(data=data, guild=self.guild, state=self.state)
class ArchivedThreadIterator(_AsyncIterator['Thread']):
def __init__(
self,
channel_id: int,
guild: Guild,
limit: Optional[int],
joined: bool,
private: bool,
before: Optional[Union[Snowflake, datetime.datetime]] = None,
):
self.channel_id = channel_id
self.guild = guild
self.limit = limit
self.joined = joined
self.private = private
self.http = guild._state.http
if joined and not private:
raise ValueError('Cannot iterate over joined public archived threads')
self.before: Optional[str]
if before is None:
self.before = None
elif isinstance(before, datetime.datetime):
if joined:
self.before = str(time_snowflake(before, high=False))
else:
self.before = before.isoformat()
else:
if joined:
self.before = str(before.id)
else:
self.before = snowflake_time(before.id).isoformat()
self.update_before: Callable[[ThreadPayload], str] = self.get_archive_timestamp
if joined:
self.endpoint = self.http.get_joined_private_archived_threads
self.update_before = self.get_thread_id
elif private:
self.endpoint = self.http.get_private_archived_threads
else:
self.endpoint = self.http.get_public_archived_threads
self.queue: asyncio.Queue[Thread] = asyncio.Queue()
self.has_more: bool = True
async def next(self) -> Thread:
if self.queue.empty():
await self.fill_queue()
try:
return self.queue.get_nowait()
except asyncio.QueueEmpty:
raise NoMoreItems()
@staticmethod
def get_archive_timestamp(data: ThreadPayload) -> str:
return data['thread_metadata']['archive_timestamp']
@staticmethod
def get_thread_id(data: ThreadPayload) -> str:
return data['id'] # type: ignore
async def fill_queue(self) -> None:
if not self.has_more:
raise NoMoreItems()
limit = 50 if self.limit is None else max(self.limit, 50)
data = await self.endpoint(self.channel_id, before=self.before, limit=limit)
# This stuff is obviously WIP because 'members' is always empty
threads: List[ThreadPayload] = data.get('threads', [])
for d in reversed(threads):
self.queue.put_nowait(self.create_thread(d))
self.has_more = data.get('has_more', False)
if self.limit is not None:
self.limit -= len(threads)
if self.limit <= 0:
self.has_more = False
if self.has_more:
self.before = self.update_before(threads[-1])
def create_thread(self, data: ThreadPayload) -> Thread:
from .threads import Thread
return Thread(guild=self.guild, state=self.guild._state, data=data)

View File

@@ -22,22 +22,18 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import datetime import datetime
import inspect import inspect
import itertools import itertools
import sys import sys
from operator import attrgetter from operator import attrgetter
from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload
import discord.abc import discord.abc
from . import utils from . import utils
from .asset import Asset from .errors import ClientException
from .utils import MISSING from .user import BaseUser, User
from .user import BaseUser, User, _UserTag from .activity import create_activity
from .activity import create_activity, ActivityTypes
from .permissions import Permissions from .permissions import Permissions
from .enums import Status, try_enum from .enums import Status, try_enum
from .colour import Colour from .colour import Colour
@@ -48,27 +44,6 @@ __all__ = (
'Member', 'Member',
) )
if TYPE_CHECKING:
from .asset import Asset
from .channel import DMChannel, VoiceChannel, StageChannel
from .flags import PublicUserFlags
from .guild import Guild
from .types.activity import PartialPresenceUpdate
from .types.member import (
MemberWithUser as MemberWithUserPayload,
Member as MemberPayload,
UserWithMember as UserWithMemberPayload,
)
from .types.user import User as UserPayload
from .abc import Snowflake
from .state import ConnectionState
from .message import Message
from .role import Role
from .types.voice import VoiceState as VoiceStatePayload
VocalGuildChannel = Union[VoiceChannel, StageChannel]
class VoiceState: class VoiceState:
"""Represents a Discord user's voice state. """Represents a Discord user's voice state.
@@ -112,49 +87,38 @@ class VoiceState:
is not currently in a voice channel. is not currently in a voice channel.
""" """
__slots__ = ( __slots__ = ('session_id', 'deaf', 'mute', 'self_mute',
'session_id', 'self_stream', 'self_video', 'self_deaf', 'afk', 'channel',
'deaf', 'requested_to_speak_at', 'suppress')
'mute',
'self_mute',
'self_stream',
'self_video',
'self_deaf',
'afk',
'channel',
'requested_to_speak_at',
'suppress',
)
def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): def __init__(self, *, data, channel=None):
self.session_id: str = data.get('session_id') self.session_id = data.get('session_id')
self._update(data, channel) self._update(data, channel)
def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): def _update(self, data, channel):
self.self_mute: bool = data.get('self_mute', False) self.self_mute = data.get('self_mute', False)
self.self_deaf: bool = data.get('self_deaf', False) self.self_deaf = data.get('self_deaf', False)
self.self_stream: bool = data.get('self_stream', False) self.self_stream = data.get('self_stream', False)
self.self_video: bool = data.get('self_video', False) self.self_video = data.get('self_video', False)
self.afk: bool = data.get('suppress', False) self.afk = data.get('suppress', False)
self.mute: bool = data.get('mute', False) self.mute = data.get('mute', False)
self.deaf: bool = data.get('deaf', False) self.deaf = data.get('deaf', False)
self.suppress: bool = data.get('suppress', False) self.suppress = data.get('suppress', False)
self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp')) self.requested_to_speak_at = utils.parse_time(data.get('request_to_speak_timestamp'))
self.channel: Optional[VocalGuildChannel] = channel self.channel = channel
def __repr__(self) -> str: def __repr__(self):
attrs = [ attrs = [
('self_mute', self.self_mute), ('self_mute', self.self_mute),
('self_deaf', self.self_deaf), ('self_deaf', self.self_deaf),
('self_stream', self.self_stream), ('self_stream', self.self_stream),
('suppress', self.suppress), ('suppress', self.suppress),
('requested_to_speak_at', self.requested_to_speak_at), ('requested_to_speak_at', self.requested_to_speak_at),
('channel', self.channel), ('channel', self.channel)
] ]
inner = ' '.join('%s=%r' % t for t in attrs) inner = ' '.join('%s=%r' % t for t in attrs)
return f'<{self.__class__.__name__} {inner}>' return f'<{self.__class__.__name__} {inner}>'
def flatten_user(cls): def flatten_user(cls):
for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()):
# ignore private/special methods # ignore private/special methods
@@ -178,12 +142,9 @@ def flatten_user(cls):
def generate_function(x): def generate_function(x):
# We want sphinx to properly show coroutine functions as coroutines # We want sphinx to properly show coroutine functions as coroutines
if inspect.iscoroutinefunction(value): if inspect.iscoroutinefunction(value):
async def general(self, *args, **kwargs):
async def general(self, *args, **kwargs): # type: ignore
return await getattr(self._user, x)(*args, **kwargs) return await getattr(self._user, x)(*args, **kwargs)
else: else:
def general(self, *args, **kwargs): def general(self, *args, **kwargs):
return getattr(self._user, x)(*args, **kwargs) return getattr(self._user, x)(*args, **kwargs)
@@ -196,12 +157,10 @@ def flatten_user(cls):
return cls return cls
_BaseUser = discord.abc.User
M = TypeVar('M', bound='Member')
@flatten_user @flatten_user
class Member(discord.abc.Messageable, _UserTag): class Member(discord.abc.Messageable, _BaseUser):
"""Represents a Discord member to a :class:`Guild`. """Represents a Discord member to a :class:`Guild`.
This implements a lot of the functionality of :class:`User`. This implements a lot of the functionality of :class:`User`.
@@ -250,98 +209,83 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.6 .. versionadded:: 1.6
premium_since: Optional[:class:`datetime.datetime`] premium_since: Optional[:class:`datetime.datetime`]
An aware datetime object that specifies the date and time in UTC when the member used their An aware datetime object that specifies the date and time in UTC when the member used their
"Nitro boost" on the guild, if available. This could be ``None``. Nitro boost on the guild, if available. This could be ``None``.
""" """
__slots__ = ( __slots__ = ('_roles', 'joined_at', 'premium_since', '_client_status',
'_roles', 'activities', 'guild', 'pending', 'nick', '_user', '_state')
'joined_at',
'premium_since',
'activities',
'guild',
'pending',
'nick',
'_client_status',
'_user',
'_state',
'_avatar',
)
if TYPE_CHECKING: def __init__(self, *, data, guild, state):
name: str self._state = state
id: int self._user = state.store_user(data['user'])
discriminator: str self.guild = guild
bot: bool self.joined_at = utils.parse_time(data.get('joined_at'))
system: bool self.premium_since = utils.parse_time(data.get('premium_since'))
created_at: datetime.datetime self._update_roles(data)
default_avatar: Asset self._client_status = {
avatar: Optional[Asset] None: 'offline'
dm_channel: Optional[DMChannel] }
create_dm = User.create_dm self.activities = tuple(map(create_activity, data.get('activities', [])))
mutual_guilds: List[Guild] self.nick = data.get('nick', None)
public_flags: PublicUserFlags self.pending = data.get('pending', False)
banner: Optional[Asset]
accent_color: Optional[Colour]
accent_colour: Optional[Colour]
def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState): def __str__(self):
self._state: ConnectionState = state
self._user: User = state.store_user(data['user'])
self.guild: Guild = guild
self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at'))
self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since'))
self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles']))
self._client_status: Dict[Optional[str], str] = {None: 'offline'}
self.activities: Tuple[ActivityTypes, ...] = tuple()
self.nick: Optional[str] = data.get('nick', None)
self.pending: bool = data.get('pending', False)
self._avatar: Optional[str] = data.get('avatar')
def __str__(self) -> str:
return str(self._user) return str(self._user)
def __repr__(self) -> str: def __int__(self):
return ( return self.id
f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}'
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
)
def __eq__(self, other: Any) -> bool: def __repr__(self):
return isinstance(other, _UserTag) and other.id == self.id return f'<Member id={self._user.id} name={self._user.name!r} discriminator={self._user.discriminator!r}' \
f' bot={self._user.bot} nick={self.nick!r} guild={self.guild!r}>'
def __ne__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, _BaseUser) and other.id == self.id
def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash(self._user) return hash(self._user)
@classmethod @classmethod
def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M: def _from_message(cls, *, message, data):
author = message.author author = message.author
data['user'] = author._to_minimal_user_json() # type: ignore data['user'] = author._to_minimal_user_json()
return cls(data=data, guild=message.guild, state=message._state) # type: ignore return cls(data=data, guild=message.guild, state=message._state)
def _update_from_message(self, data: MemberPayload) -> None: def _update_from_message(self, data):
self.joined_at = utils.parse_time(data.get('joined_at')) self.joined_at = utils.parse_time(data.get('joined_at'))
self.premium_since = utils.parse_time(data.get('premium_since')) self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles'])) self._update_roles(data)
self.nick = data.get('nick', None) self.nick = data.get('nick', None)
self.pending = data.get('pending', False) self.pending = data.get('pending', False)
@classmethod @classmethod
def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]: def _try_upgrade(cls, *, data, guild, state):
# A User object with a 'member' key # A User object with a 'member' key
try: try:
member_data = data.pop('member') member_data = data.pop('member')
except KeyError: except KeyError:
return state.create_user(data) return state.store_user(data)
else: else:
member_data['user'] = data # type: ignore member_data['user'] = data
return cls(data=member_data, guild=guild, state=state) # type: ignore return cls(data=member_data, guild=guild, state=state)
@classmethod @classmethod
def _copy(cls: Type[M], member: M) -> M: def _from_presence_update(cls, *, data, guild, state):
self: M = cls.__new__(cls) # to bypass __init__ clone = cls(data=data, guild=guild, state=state)
to_return = cls(data=data, guild=guild, state=state)
to_return._client_status = {
sys.intern(key): sys.intern(value)
for key, value in data.get('client_status', {}).items()
}
to_return._client_status[None] = sys.intern(data['status'])
return to_return, clone
@classmethod
def _copy(cls, member):
self = cls.__new__(cls) # to bypass __init__
self._roles = utils.SnowflakeList(member._roles, is_sorted=True) self._roles = utils.SnowflakeList(member._roles, is_sorted=True)
self.joined_at = member.joined_at self.joined_at = member.joined_at
@@ -352,7 +296,6 @@ class Member(discord.abc.Messageable, _UserTag):
self.pending = member.pending self.pending = member.pending
self.activities = member.activities self.activities = member.activities
self._state = member._state self._state = member._state
self._avatar = member._avatar
# Reference will not be copied unless necessary by PRESENCE_UPDATE # Reference will not be copied unless necessary by PRESENCE_UPDATE
# See below # See below
@@ -363,7 +306,10 @@ class Member(discord.abc.Messageable, _UserTag):
ch = await self.create_dm() ch = await self.create_dm()
return ch return ch
def _update(self, data: MemberPayload) -> None: def _update_roles(self, data):
self._roles = utils.SnowflakeList(map(int, data['roles']))
def _update(self, data):
# the nickname change is optional, # the nickname change is optional,
# if it isn't in the payload then it didn't change # if it isn't in the payload then it didn't change
try: try:
@@ -377,38 +323,38 @@ class Member(discord.abc.Messageable, _UserTag):
pass pass
self.premium_since = utils.parse_time(data.get('premium_since')) self.premium_since = utils.parse_time(data.get('premium_since'))
self._roles = utils.SnowflakeList(map(int, data['roles'])) self._update_roles(data)
self._avatar = data.get('avatar')
def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: def _presence_update(self, data, user):
self.activities = tuple(map(create_activity, data['activities'])) self.activities = tuple(map(create_activity, data.get('activities', [])))
self._client_status = { self._client_status = {
sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore sys.intern(key): sys.intern(value)
for key, value in data.get('client_status', {}).items()
} }
self._client_status[None] = sys.intern(data['status']) self._client_status[None] = sys.intern(data['status'])
if len(user) > 1: if len(user) > 1:
return self._update_inner_user(user) return self._update_inner_user(user)
return None return False
def _update_inner_user(self, user: UserPayload) -> Optional[Tuple[User, User]]: def _update_inner_user(self, user):
u = self._user u = self._user
original = (u.name, u._avatar, u.discriminator, u._public_flags) original = (u.name, u.avatar, u.discriminator, u._public_flags)
# These keys seem to always be available # These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0)) modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0))
if original != modified: if original != modified:
to_return = User._copy(self._user) to_return = User._copy(self._user)
u.name, u._avatar, u.discriminator, u._public_flags = modified u.name, u.avatar, u.discriminator, u._public_flags = modified
# Signal to dispatch on_user_update # Signal to dispatch on_user_update
return to_return, u return to_return, u
@property @property
def status(self) -> Status: def status(self):
""":class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead.""" """:class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead."""
return try_enum(Status, self._client_status[None]) return try_enum(Status, self._client_status[None])
@property @property
def raw_status(self) -> str: def raw_status(self):
""":class:`str`: The member's overall status as a string value. """:class:`str`: The member's overall status as a string value.
.. versionadded:: 1.5 .. versionadded:: 1.5
@@ -416,31 +362,31 @@ class Member(discord.abc.Messageable, _UserTag):
return self._client_status[None] return self._client_status[None]
@status.setter @status.setter
def status(self, value: Status) -> None: def status(self, value):
# internal use only # internal use only
self._client_status[None] = str(value) self._client_status[None] = str(value)
@property @property
def mobile_status(self) -> Status: def mobile_status(self):
""":class:`Status`: The member's status on a mobile device, if applicable.""" """:class:`Status`: The member's status on a mobile device, if applicable."""
return try_enum(Status, self._client_status.get('mobile', 'offline')) return try_enum(Status, self._client_status.get('mobile', 'offline'))
@property @property
def desktop_status(self) -> Status: def desktop_status(self):
""":class:`Status`: The member's status on the desktop client, if applicable.""" """:class:`Status`: The member's status on the desktop client, if applicable."""
return try_enum(Status, self._client_status.get('desktop', 'offline')) return try_enum(Status, self._client_status.get('desktop', 'offline'))
@property @property
def web_status(self) -> Status: def web_status(self):
""":class:`Status`: The member's status on the web client, if applicable.""" """:class:`Status`: The member's status on the web client, if applicable."""
return try_enum(Status, self._client_status.get('web', 'offline')) return try_enum(Status, self._client_status.get('web', 'offline'))
def is_on_mobile(self) -> bool: def is_on_mobile(self):
""":class:`bool`: A helper function that determines if a member is active on a mobile device.""" """:class:`bool`: A helper function that determines if a member is active on a mobile device."""
return 'mobile' in self._client_status return 'mobile' in self._client_status
@property @property
def colour(self) -> Colour: def colour(self):
""":class:`Colour`: A property that returns a colour denoting the rendered colour """:class:`Colour`: A property that returns a colour denoting the rendered colour
for the member. If the default colour is the one rendered then an instance for the member. If the default colour is the one rendered then an instance
of :meth:`Colour.default` is returned. of :meth:`Colour.default` is returned.
@@ -448,7 +394,7 @@ class Member(discord.abc.Messageable, _UserTag):
There is an alias for this named :attr:`color`. There is an alias for this named :attr:`color`.
""" """
roles = self.roles[1:] # remove @everyone roles = self.roles[1:] # remove @everyone
# highest order of the colour is the one that gets rendered. # highest order of the colour is the one that gets rendered.
# if the highest is the default colour then the next one with a colour # if the highest is the default colour then the next one with a colour
@@ -459,7 +405,7 @@ class Member(discord.abc.Messageable, _UserTag):
return Colour.default() return Colour.default()
@property @property
def color(self) -> Colour: def color(self):
""":class:`Colour`: A property that returns a color denoting the rendered color for """:class:`Colour`: A property that returns a color denoting the rendered color for
the member. If the default color is the one rendered then an instance of :meth:`Colour.default` the member. If the default color is the one rendered then an instance of :meth:`Colour.default`
is returned. is returned.
@@ -469,7 +415,7 @@ class Member(discord.abc.Messageable, _UserTag):
return self.colour return self.colour
@property @property
def roles(self) -> List[Role]: def roles(self):
"""List[:class:`Role`]: A :class:`list` of :class:`Role` that the member belongs to. Note """List[:class:`Role`]: A :class:`list` of :class:`Role` that the member belongs to. Note
that the first element of this list is always the default '@everyone' that the first element of this list is always the default '@everyone'
role. role.
@@ -487,14 +433,14 @@ class Member(discord.abc.Messageable, _UserTag):
return result return result
@property @property
def mention(self) -> str: def mention(self):
""":class:`str`: Returns a string that allows you to mention the member.""" """:class:`str`: Returns a string that allows you to mention the member."""
if self.nick: if self.nick:
return f'<@!{self._user.id}>' return f'<@!{self._user.id}>'
return f'<@{self._user.id}>' return f'<@{self._user.id}>'
@property @property
def display_name(self) -> str: def display_name(self):
""":class:`str`: Returns the user's display name. """:class:`str`: Returns the user's display name.
For regular users this is just their username, but For regular users this is just their username, but
@@ -504,36 +450,13 @@ class Member(discord.abc.Messageable, _UserTag):
return self.nick or self.name return self.nick or self.name
@property @property
def display_avatar(self) -> Asset: def activity(self):
""":class:`Asset`: Returns the member's display avatar. """Union[:class:`BaseActivity`, :class:`Spotify`]: Returns the primary
For regular members this is just their avatar, but
if they have a guild specific avatar then that
is returned instead.
.. versionadded:: 2.0
"""
return self.guild_avatar or self._user.avatar or self._user.default_avatar
@property
def guild_avatar(self) -> Optional[Asset]:
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the guild avatar
the member has. If unavailable, ``None`` is returned.
.. versionadded:: 2.0
"""
if self._avatar is None:
return None
return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar)
@property
def activity(self) -> Optional[ActivityTypes]:
"""Optional[Union[:class:`BaseActivity`, :class:`Spotify`]]: Returns the primary
activity the user is currently doing. Could be ``None`` if no activity is being done. activity the user is currently doing. Could be ``None`` if no activity is being done.
.. note:: .. note::
Due to a Discord API limitation, this may be ``None`` if Due to a Discord API limitation, this may be ``None`` if
the user is listening to a song on Spotify with a title longer the user is listening to a song on Spotify with a title longer
than 128 characters. See :issue:`1738` for more information. than 128 characters. See :issue:`1738` for more information.
@@ -544,7 +467,7 @@ class Member(discord.abc.Messageable, _UserTag):
if self.activities: if self.activities:
return self.activities[0] return self.activities[0]
def mentioned_in(self, message: Message) -> bool: def mentioned_in(self, message):
"""Checks if the member is mentioned in the specified message. """Checks if the member is mentioned in the specified message.
Parameters Parameters
@@ -565,8 +488,29 @@ class Member(discord.abc.Messageable, _UserTag):
return any(self._roles.has(role.id) for role in message.role_mentions) return any(self._roles.has(role.id) for role in message.role_mentions)
def permissions_in(self, channel):
"""An alias for :meth:`abc.GuildChannel.permissions_for`.
Basically equivalent to:
.. code-block:: python3
channel.permissions_for(self)
Parameters
-----------
channel: :class:`abc.GuildChannel`
The channel to check your permissions for.
Returns
-------
:class:`Permissions`
The resolved permissions for the member.
"""
return channel.permissions_for(self)
@property @property
def top_role(self) -> Role: def top_role(self):
""":class:`Role`: Returns the member's highest role. """:class:`Role`: Returns the member's highest role.
This is useful for figuring where a member stands in the role This is useful for figuring where a member stands in the role
@@ -579,13 +523,14 @@ class Member(discord.abc.Messageable, _UserTag):
return max(guild.get_role(rid) or guild.default_role for rid in self._roles) return max(guild.get_role(rid) or guild.default_role for rid in self._roles)
@property @property
def guild_permissions(self) -> Permissions: def guild_permissions(self):
""":class:`Permissions`: Returns the member's guild permissions. """:class:`Permissions`: Returns the member's guild permissions.
This only takes into consideration the guild permissions This only takes into consideration the guild permissions
and not most of the implied permissions or any of the and not most of the implied permissions or any of the
channel permission overwrites. For 100% accurate permission channel permission overwrites. For 100% accurate permission
calculation, please use :meth:`abc.GuildChannel.permissions_for`. calculation, please use either :meth:`permissions_in` or
:meth:`abc.GuildChannel.permissions_for`.
This does take into consideration guild ownership and the This does take into consideration guild ownership and the
administrator implication. administrator implication.
@@ -604,47 +549,32 @@ class Member(discord.abc.Messageable, _UserTag):
return base return base
@property @property
def voice(self) -> Optional[VoiceState]: def voice(self):
"""Optional[:class:`VoiceState`]: Returns the member's current voice state.""" """Optional[:class:`VoiceState`]: Returns the member's current voice state."""
return self.guild._voice_state_for(self._user.id) return self.guild._voice_state_for(self._user.id)
async def ban( async def ban(self, **kwargs):
self,
*,
delete_message_days: Literal[0, 1, 2, 3, 4, 5, 6, 7] = 1,
reason: Optional[str] = None,
) -> None:
"""|coro| """|coro|
Bans this member. Equivalent to :meth:`Guild.ban`. Bans this member. Equivalent to :meth:`Guild.ban`.
""" """
await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days) await self.guild.ban(self, **kwargs)
async def unban(self, *, reason: Optional[str] = None) -> None: async def unban(self, *, reason=None):
"""|coro| """|coro|
Unbans this member. Equivalent to :meth:`Guild.unban`. Unbans this member. Equivalent to :meth:`Guild.unban`.
""" """
await self.guild.unban(self, reason=reason) await self.guild.unban(self, reason=reason)
async def kick(self, *, reason: Optional[str] = None) -> None: async def kick(self, *, reason=None):
"""|coro| """|coro|
Kicks this member. Equivalent to :meth:`Guild.kick`. Kicks this member. Equivalent to :meth:`Guild.kick`.
""" """
await self.guild.kick(self, reason=reason) await self.guild.kick(self, reason=reason)
async def edit( async def edit(self, *, reason=None, **fields):
self,
*,
nick: Optional[str] = MISSING,
mute: bool = MISSING,
deafen: bool = MISSING,
suppress: bool = MISSING,
roles: List[discord.abc.Snowflake] = MISSING,
voice_channel: Optional[VocalGuildChannel] = MISSING,
reason: Optional[str] = None,
) -> Optional[Member]:
"""|coro| """|coro|
Edits the member's data. Edits the member's data.
@@ -670,9 +600,6 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionchanged:: 1.1 .. versionchanged:: 1.1
Can now pass ``None`` to ``voice_channel`` to kick a member from voice. Can now pass ``None`` to ``voice_channel`` to kick a member from voice.
.. versionchanged:: 2.0
The newly member is now optionally returned, if applicable.
Parameters Parameters
----------- -----------
nick: Optional[:class:`str`] nick: Optional[:class:`str`]
@@ -686,7 +613,7 @@ class Member(discord.abc.Messageable, _UserTag):
.. versionadded:: 1.7 .. versionadded:: 1.7
roles: List[:class:`Role`] roles: Optional[List[:class:`Role`]]
The member's new list of roles. This *replaces* the roles. The member's new list of roles. This *replaces* the roles.
voice_channel: Optional[:class:`VoiceChannel`] voice_channel: Optional[:class:`VoiceChannel`]
The voice channel to move the member to. The voice channel to move the member to.
@@ -700,32 +627,34 @@ class Member(discord.abc.Messageable, _UserTag):
You do not have the proper permissions to the action requested. You do not have the proper permissions to the action requested.
HTTPException HTTPException
The operation failed. The operation failed.
Returns
--------
Optional[:class:`.Member`]
The newly updated member, if applicable. This is only returned
when certain fields are updated.
""" """
http = self._state.http http = self._state.http
guild_id = self.guild.id guild_id = self.guild.id
me = self._state.self_id == self.id me = self._state.self_id == self.id
payload: Dict[str, Any] = {} payload = {}
if nick is not MISSING: try:
nick = fields['nick']
except KeyError:
# nick not present so...
pass
else:
nick = nick or '' nick = nick or ''
if me: if me:
await http.change_my_nickname(guild_id, nick, reason=reason) await http.change_my_nickname(guild_id, nick, reason=reason)
else: else:
payload['nick'] = nick payload['nick'] = nick
if deafen is not MISSING: deafen = fields.get('deafen')
if deafen is not None:
payload['deaf'] = deafen payload['deaf'] = deafen
if mute is not MISSING: mute = fields.get('mute')
if mute is not None:
payload['mute'] = mute payload['mute'] = mute
if suppress is not MISSING: suppress = fields.get('suppress')
if suppress is not None:
voice_state_payload = { voice_state_payload = {
'channel_id': self.voice.channel.id, 'channel_id': self.voice.channel.id,
'suppress': suppress, 'suppress': suppress,
@@ -741,17 +670,26 @@ class Member(discord.abc.Messageable, _UserTag):
voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat() voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat()
await http.edit_voice_state(guild_id, self.id, voice_state_payload) await http.edit_voice_state(guild_id, self.id, voice_state_payload)
if voice_channel is not MISSING: try:
payload['channel_id'] = voice_channel and voice_channel.id vc = fields['voice_channel']
except KeyError:
pass
else:
payload['channel_id'] = vc and vc.id
if roles is not MISSING: try:
roles = fields['roles']
except KeyError:
pass
else:
payload['roles'] = tuple(r.id for r in roles) payload['roles'] = tuple(r.id for r in roles)
if payload: if payload:
data = await http.edit_member(guild_id, self.id, reason=reason, **payload) await http.edit_member(guild_id, self.id, reason=reason, **payload)
return Member(data=data, guild=self.guild, state=self._state)
async def request_to_speak(self) -> None: # TODO: wait for WS event for modify-in-place behaviour
async def request_to_speak(self):
"""|coro| """|coro|
Request to speak in the connected channel. Request to speak in the connected channel.
@@ -783,7 +721,7 @@ class Member(discord.abc.Messageable, _UserTag):
else: else:
await self._state.http.edit_my_voice_state(self.guild.id, payload) await self._state.http.edit_my_voice_state(self.guild.id, payload)
async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None: async def move_to(self, channel, *, reason=None):
"""|coro| """|coro|
Moves a member to a new voice channel (they must be connected first). Moves a member to a new voice channel (they must be connected first).
@@ -806,7 +744,7 @@ class Member(discord.abc.Messageable, _UserTag):
""" """
await self.edit(voice_channel=channel, reason=reason) await self.edit(voice_channel=channel, reason=reason)
async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: async def add_roles(self, *roles, reason=None, atomic=True):
r"""|coro| r"""|coro|
Gives the member a number of :class:`Role`\s. Gives the member a number of :class:`Role`\s.
@@ -845,7 +783,7 @@ class Member(discord.abc.Messageable, _UserTag):
for role in roles: for role in roles:
await req(guild_id, user_id, role.id, reason=reason) await req(guild_id, user_id, role.id, reason=reason)
async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: async def remove_roles(self, *roles, reason=None, atomic=True):
r"""|coro| r"""|coro|
Removes :class:`Role`\s from this member. Removes :class:`Role`\s from this member.
@@ -875,7 +813,7 @@ class Member(discord.abc.Messageable, _UserTag):
""" """
if not atomic: if not atomic:
new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone
for role in roles: for role in roles:
try: try:
new_roles.remove(Object(id=role.id)) new_roles.remove(Object(id=role.id))
@@ -889,20 +827,3 @@ class Member(discord.abc.Messageable, _UserTag):
user_id = self.id user_id = self.id
for role in roles: for role in roles:
await req(guild_id, user_id, role.id, reason=reason) await req(guild_id, user_id, role.id, reason=reason)
def get_role(self, role_id: int, /) -> Optional[Role]:
"""Returns a role with the given ID from roles which the member has.
.. versionadded:: 2.0
Parameters
-----------
role_id: :class:`int`
The ID to search for.
Returns
--------
Optional[:class:`Role`]
The role or ``None`` if not found in the member's roles.
"""
return self.guild.get_role(role_id) if self._roles.has(role_id) else None

View File

@@ -22,18 +22,10 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union
__all__ = ( __all__ = (
'AllowedMentions', 'AllowedMentions',
) )
if TYPE_CHECKING:
from .types.message import AllowedMentions as AllowedMentionsPayload
from .abc import Snowflake
class _FakeBool: class _FakeBool:
def __repr__(self): def __repr__(self):
return 'True' return 'True'
@@ -44,11 +36,7 @@ class _FakeBool:
def __bool__(self): def __bool__(self):
return True return True
default = _FakeBool()
default: Any = _FakeBool()
A = TypeVar('A', bound='AllowedMentions')
class AllowedMentions: class AllowedMentions:
"""A class that represents what mentions are allowed in a message. """A class that represents what mentions are allowed in a message.
@@ -82,21 +70,14 @@ class AllowedMentions:
__slots__ = ('everyone', 'users', 'roles', 'replied_user') __slots__ = ('everyone', 'users', 'roles', 'replied_user')
def __init__( def __init__(self, *, everyone=default, users=default, roles=default, replied_user=default):
self,
*,
everyone: bool = default,
users: Union[bool, List[Snowflake]] = default,
roles: Union[bool, List[Snowflake]] = default,
replied_user: bool = default,
):
self.everyone = everyone self.everyone = everyone
self.users = users self.users = users
self.roles = roles self.roles = roles
self.replied_user = replied_user self.replied_user = replied_user
@classmethod @classmethod
def all(cls: Type[A]) -> A: def all(cls):
"""A factory method that returns a :class:`AllowedMentions` with all fields explicitly set to ``True`` """A factory method that returns a :class:`AllowedMentions` with all fields explicitly set to ``True``
.. versionadded:: 1.5 .. versionadded:: 1.5
@@ -104,14 +85,14 @@ class AllowedMentions:
return cls(everyone=True, users=True, roles=True, replied_user=True) return cls(everyone=True, users=True, roles=True, replied_user=True)
@classmethod @classmethod
def none(cls: Type[A]) -> A: def none(cls):
"""A factory method that returns a :class:`AllowedMentions` with all fields set to ``False`` """A factory method that returns a :class:`AllowedMentions` with all fields set to ``False``
.. versionadded:: 1.5 .. versionadded:: 1.5
""" """
return cls(everyone=False, users=False, roles=False, replied_user=False) return cls(everyone=False, users=False, roles=False, replied_user=False)
def to_dict(self) -> AllowedMentionsPayload: def to_dict(self):
parse = [] parse = []
data = {} data = {}
@@ -132,9 +113,9 @@ class AllowedMentions:
data['replied_user'] = True data['replied_user'] = True
data['parse'] = parse data['parse'] = parse
return data # type: ignore return data
def merge(self, other: AllowedMentions) -> AllowedMentions: def merge(self, other):
# Creates a new AllowedMentions by merging from another one. # Creates a new AllowedMentions by merging from another one.
# Merge is done by using the 'self' values unless explicitly # Merge is done by using the 'self' values unless explicitly
# overridden by the 'other' values. # overridden by the 'other' values.
@@ -144,8 +125,5 @@ class AllowedMentions:
replied_user = self.replied_user if other.replied_user is default else other.replied_user replied_user = self.replied_user if other.replied_user is default else other.replied_user
return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user) return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user)
def __repr__(self) -> str: def __repr__(self):
return ( return '{0.__class__.__qualname__}(everyone={0.everyone}, users={0.users}, roles={0.roles}, replied_user={0.replied_user})'.format(self)
f'{self.__class__.__name__}(everyone={self.everyone}, '
f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})'
)

File diff suppressed because it is too large Load Diff

View File

@@ -30,12 +30,10 @@ __all__ = (
class EqualityComparable: class EqualityComparable:
__slots__ = () __slots__ = ()
id: int def __eq__(self, other):
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and other.id == self.id return isinstance(other, self.__class__) and other.id == self.id
def __ne__(self, other: object) -> bool: def __ne__(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return other.id != self.id return other.id != self.id
return True return True
@@ -43,5 +41,5 @@ class EqualityComparable:
class Hashable(EqualityComparable): class Hashable(EqualityComparable):
__slots__ = () __slots__ = ()
def __hash__(self) -> int: def __hash__(self):
return self.id >> 22 return self.id >> 22

View File

@@ -22,21 +22,9 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from . import utils from . import utils
from .mixins import Hashable from .mixins import Hashable
from typing import (
SupportsInt,
TYPE_CHECKING,
Union,
)
if TYPE_CHECKING:
import datetime
SupportsIntCast = Union[SupportsInt, str, bytes, bytearray]
__all__ = ( __all__ = (
'Object', 'Object',
) )
@@ -75,7 +63,7 @@ class Object(Hashable):
The ID of the object. The ID of the object.
""" """
def __init__(self, id: SupportsIntCast): def __init__(self, id):
try: try:
id = int(id) id = int(id)
except ValueError: except ValueError:
@@ -83,10 +71,10 @@ class Object(Hashable):
else: else:
self.id = id self.id = id
def __repr__(self) -> str: def __repr__(self):
return f'<Object id={self.id!r}>' return f'<Object id={self.id!r}>'
@property @property
def created_at(self) -> datetime.datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the snowflake's creation time in UTC.""" """:class:`datetime.datetime`: Returns the snowflake's creation time in UTC."""
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)

View File

@@ -22,12 +22,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import struct import struct
from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional
from .errors import DiscordException from .errors import DiscordException
__all__ = ( __all__ = (
@@ -44,29 +40,22 @@ class OggError(DiscordException):
# https://tools.ietf.org/html/rfc7845 # https://tools.ietf.org/html/rfc7845
class OggPage: class OggPage:
_header: ClassVar[struct.Struct] = struct.Struct('<xBQIIIB') _header = struct.Struct('<xBQIIIB')
if TYPE_CHECKING:
flag: int
gran_pos: int
serial: int
pagenum: int
crc: int
segnum: int
def __init__(self, stream: IO[bytes]) -> None: def __init__(self, stream):
try: try:
header = stream.read(struct.calcsize(self._header.format)) header = stream.read(struct.calcsize(self._header.format))
self.flag, self.gran_pos, self.serial, \ self.flag, self.gran_pos, self.serial, \
self.pagenum, self.crc, self.segnum = self._header.unpack(header) self.pagenum, self.crc, self.segnum = self._header.unpack(header)
self.segtable: bytes = stream.read(self.segnum) self.segtable = stream.read(self.segnum)
bodylen = sum(struct.unpack('B'*self.segnum, self.segtable)) bodylen = sum(struct.unpack('B'*self.segnum, self.segtable))
self.data: bytes = stream.read(bodylen) self.data = stream.read(bodylen)
except Exception: except Exception:
raise OggError('bad data stream') from None raise OggError('bad data stream') from None
def iter_packets(self) -> Generator[Tuple[bytes, bool], None, None]: def iter_packets(self):
packetlen = offset = 0 packetlen = offset = 0
partial = True partial = True
@@ -85,10 +74,10 @@ class OggPage:
yield self.data[offset:], False yield self.data[offset:], False
class OggStream: class OggStream:
def __init__(self, stream: IO[bytes]) -> None: def __init__(self, stream):
self.stream: IO[bytes] = stream self.stream = stream
def _next_page(self) -> Optional[OggPage]: def _next_page(self):
head = self.stream.read(4) head = self.stream.read(4)
if head == b'OggS': if head == b'OggS':
return OggPage(self.stream) return OggPage(self.stream)
@@ -97,13 +86,13 @@ class OggStream:
else: else:
raise OggError('invalid header magic') raise OggError('invalid header magic')
def _iter_pages(self) -> Generator[OggPage, None, None]: def _iter_pages(self):
page = self._next_page() page = self._next_page()
while page: while page:
yield page yield page
page = self._next_page() page = self._next_page()
def iter_packets(self) -> Generator[bytes, None, None]: def iter_packets(self):
partial = b'' partial = b''
for page in self._iter_pages(): for page in self._iter_pages():
for data, complete in page.iter_packets(): for data, complete in page.iter_packets():

View File

@@ -22,10 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload
import array import array
import ctypes import ctypes
import ctypes.util import ctypes.util
@@ -35,24 +31,7 @@ import os.path
import struct import struct
import sys import sys
from .errors import DiscordException, InvalidArgument from .errors import DiscordException
if TYPE_CHECKING:
T = TypeVar('T')
BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full']
SIGNAL_CTL = Literal['auto', 'voice', 'music']
class BandCtl(TypedDict):
narrow: int
medium: int
wide: int
superwide: int
full: int
class SignalCtl(TypedDict):
auto: int
voice: int
music: int
__all__ = ( __all__ = (
'Encoder', 'Encoder',
@@ -60,7 +39,7 @@ __all__ = (
'OpusNotLoaded', 'OpusNotLoaded',
) )
_log = logging.getLogger(__name__) log = logging.getLogger(__name__)
c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int_ptr = ctypes.POINTER(ctypes.c_int)
c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_int16_ptr = ctypes.POINTER(ctypes.c_int16)
@@ -97,7 +76,7 @@ CTL_SET_SIGNAL = 4024
CTL_SET_GAIN = 4034 CTL_SET_GAIN = 4034
CTL_LAST_PACKET_DURATION = 4039 CTL_LAST_PACKET_DURATION = 4039
band_ctl: BandCtl = { band_ctl = {
'narrow': 1101, 'narrow': 1101,
'medium': 1102, 'medium': 1102,
'wide': 1103, 'wide': 1103,
@@ -105,22 +84,22 @@ band_ctl: BandCtl = {
'full': 1105, 'full': 1105,
} }
signal_ctl: SignalCtl = { signal_ctl = {
'auto': -1000, 'auto': -1000,
'voice': 3001, 'voice': 3001,
'music': 3002, 'music': 3002,
} }
def _err_lt(result: int, func: Callable, args: List) -> int: def _err_lt(result, func, args):
if result < OK: if result < OK:
_log.info('error has happened in %s', func.__name__) log.info('error has happened in %s', func.__name__)
raise OpusError(result) raise OpusError(result)
return result return result
def _err_ne(result: T, func: Callable, args: List) -> T: def _err_ne(result, func, args):
ret = args[-1]._obj ret = args[-1]._obj
if ret.value != OK: if ret.value != OK:
_log.info('error has happened in %s', func.__name__) log.info('error has happened in %s', func.__name__)
raise OpusError(ret.value) raise OpusError(ret.value)
return result return result
@@ -129,7 +108,7 @@ def _err_ne(result: T, func: Callable, args: List) -> T:
# The second one are the types of arguments it takes. # The second one are the types of arguments it takes.
# The third is the result type. # The third is the result type.
# The fourth is the error handler. # The fourth is the error handler.
exported_functions: List[Tuple[Any, ...]] = [ exported_functions = [
# Generic # Generic
('opus_get_version_string', ('opus_get_version_string',
None, ctypes.c_char_p, None), None, ctypes.c_char_p, None),
@@ -179,7 +158,7 @@ exported_functions: List[Tuple[Any, ...]] = [
[ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt),
] ]
def libopus_loader(name: str) -> Any: def libopus_loader(name):
# create the library... # create the library...
lib = ctypes.cdll.LoadLibrary(name) lib = ctypes.cdll.LoadLibrary(name)
@@ -199,11 +178,11 @@ def libopus_loader(name: str) -> Any:
if item[3]: if item[3]:
func.errcheck = item[3] func.errcheck = item[3]
except KeyError: except KeyError:
_log.exception("Error assigning check function to %s", func) log.exception("Error assigning check function to %s", func)
return lib return lib
def _load_default() -> bool: def _load_default():
global _lib global _lib
try: try:
if sys.platform == 'win32': if sys.platform == 'win32':
@@ -219,7 +198,7 @@ def _load_default() -> bool:
return _lib is not None return _lib is not None
def load_opus(name: str) -> None: def load_opus(name):
"""Loads the libopus shared library for use with voice. """Loads the libopus shared library for use with voice.
If this function is not called then the library uses the function If this function is not called then the library uses the function
@@ -257,7 +236,7 @@ def load_opus(name: str) -> None:
global _lib global _lib
_lib = libopus_loader(name) _lib = libopus_loader(name)
def is_loaded() -> bool: def is_loaded():
"""Function to check if opus lib is successfully loaded either """Function to check if opus lib is successfully loaded either
via the :func:`ctypes.util.find_library` call of :func:`load_opus`. via the :func:`ctypes.util.find_library` call of :func:`load_opus`.
@@ -280,10 +259,10 @@ class OpusError(DiscordException):
The error code returned. The error code returned.
""" """
def __init__(self, code: int): def __init__(self, code):
self.code: int = code self.code = code
msg = _lib.opus_strerror(self.code).decode('utf-8') msg = _lib.opus_strerror(self.code).decode('utf-8')
_log.info('"%s" has happened', msg) log.info('"%s" has happened', msg)
super().__init__(msg) super().__init__(msg)
class OpusNotLoaded(DiscordException): class OpusNotLoaded(DiscordException):
@@ -307,96 +286,92 @@ class _OpusStruct:
return _lib.opus_get_version_string().decode('utf-8') return _lib.opus_get_version_string().decode('utf-8')
class Encoder(_OpusStruct): class Encoder(_OpusStruct):
def __init__(self, application: int = APPLICATION_AUDIO): def __init__(self, application=APPLICATION_AUDIO):
_OpusStruct.get_opus_version() _OpusStruct.get_opus_version()
self.application: int = application self.application = application
self._state: EncoderStruct = self._create_state() self._state = self._create_state()
self.set_bitrate(128) self.set_bitrate(128)
self.set_fec(True) self.set_fec(True)
self.set_expected_packet_loss_percent(0.15) self.set_expected_packet_loss_percent(0.15)
self.set_bandwidth('full') self.set_bandwidth('full')
self.set_signal_type('auto') self.set_signal_type('auto')
def __del__(self) -> None: def __del__(self):
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_encoder_destroy(self._state) _lib.opus_encoder_destroy(self._state)
# This is a destructor, so it's okay to assign None self._state = None
self._state = None # type: ignore
def _create_state(self) -> EncoderStruct: def _create_state(self):
ret = ctypes.c_int() ret = ctypes.c_int()
return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret))
def set_bitrate(self, kbps: int) -> int: def set_bitrate(self, kbps):
kbps = min(512, max(16, int(kbps))) kbps = min(512, max(16, int(kbps)))
_lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024) _lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024)
return kbps return kbps
def set_bandwidth(self, req: BAND_CTL) -> None: def set_bandwidth(self, req):
if req not in band_ctl: if req not in band_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}')
k = band_ctl[req] k = band_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k)
def set_signal_type(self, req: SIGNAL_CTL) -> None: def set_signal_type(self, req):
if req not in signal_ctl: if req not in signal_ctl:
raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}')
k = signal_ctl[req] k = signal_ctl[req]
_lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k)
def set_fec(self, enabled: bool = True) -> None: def set_fec(self, enabled=True):
_lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0)
def set_expected_packet_loss_percent(self, percentage: float) -> None: def set_expected_packet_loss_percent(self, percentage):
_lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100))))
def encode(self, pcm: bytes, frame_size: int) -> bytes: def encode(self, pcm, frame_size):
max_data_bytes = len(pcm) max_data_bytes = len(pcm)
# bytes can be used to reference pointer pcm = ctypes.cast(pcm, c_int16_ptr)
pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore
data = (ctypes.c_char * max_data_bytes)() data = (ctypes.c_char * max_data_bytes)()
ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes)
# array can be initialized with bytes but mypy doesn't know return array.array('b', data[:ret]).tobytes()
return array.array('b', data[:ret]).tobytes() # type: ignore
class Decoder(_OpusStruct): class Decoder(_OpusStruct):
def __init__(self): def __init__(self):
_OpusStruct.get_opus_version() _OpusStruct.get_opus_version()
self._state: DecoderStruct = self._create_state() self._state = self._create_state()
def __del__(self) -> None: def __del__(self):
if hasattr(self, '_state'): if hasattr(self, '_state'):
_lib.opus_decoder_destroy(self._state) _lib.opus_decoder_destroy(self._state)
# This is a destructor, so it's okay to assign None self._state = None
self._state = None # type: ignore
def _create_state(self) -> DecoderStruct: def _create_state(self):
ret = ctypes.c_int() ret = ctypes.c_int()
return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret))
@staticmethod @staticmethod
def packet_get_nb_frames(data: bytes) -> int: def packet_get_nb_frames(data):
"""Gets the number of frames in an Opus packet""" """Gets the number of frames in an Opus packet"""
return _lib.opus_packet_get_nb_frames(data, len(data)) return _lib.opus_packet_get_nb_frames(data, len(data))
@staticmethod @staticmethod
def packet_get_nb_channels(data: bytes) -> int: def packet_get_nb_channels(data):
"""Gets the number of channels in an Opus packet""" """Gets the number of channels in an Opus packet"""
return _lib.opus_packet_get_nb_channels(data) return _lib.opus_packet_get_nb_channels(data)
@classmethod @classmethod
def packet_get_samples_per_frame(cls, data: bytes) -> int: def packet_get_samples_per_frame(cls, data):
"""Gets the number of samples per frame from an Opus packet""" """Gets the number of samples per frame from an Opus packet"""
return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE) return _lib.opus_packet_get_samples_per_frame(data, cls.SAMPLING_RATE)
def _set_gain(self, adjustment: int) -> int: def _set_gain(self, adjustment):
"""Configures decoder gain adjustment. """Configures decoder gain adjustment.
Scales the decoded output by a factor specified in Q8 dB units. Scales the decoded output by a factor specified in Q8 dB units.
@@ -408,34 +383,26 @@ class Decoder(_OpusStruct):
""" """
return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment) return _lib.opus_decoder_ctl(self._state, CTL_SET_GAIN, adjustment)
def set_gain(self, dB: float) -> int: def set_gain(self, dB):
"""Sets the decoder gain in dB, from -128 to 128.""" """Sets the decoder gain in dB, from -128 to 128."""
dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8)
return self._set_gain(dB_Q8) return self._set_gain(dB_Q8)
def set_volume(self, mult: float) -> int: def set_volume(self, mult):
"""Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc."""
return self.set_gain(20 * math.log10(mult)) # amplitude ratio return self.set_gain(20 * math.log10(mult)) # amplitude ratio
def _get_last_packet_duration(self) -> int: def _get_last_packet_duration(self):
"""Gets the duration (in samples) of the last packet successfully decoded or concealed.""" """Gets the duration (in samples) of the last packet successfully decoded or concealed."""
ret = ctypes.c_int32() ret = ctypes.c_int32()
_lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret)) _lib.opus_decoder_ctl(self._state, CTL_LAST_PACKET_DURATION, ctypes.byref(ret))
return ret.value return ret.value
@overload def decode(self, data, *, fec=False):
def decode(self, data: bytes, *, fec: bool) -> bytes:
...
@overload
def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes:
...
def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes:
if data is None and fec: if data is None and fec:
raise InvalidArgument("Invalid arguments: FEC cannot be used with null data") raise OpusError("Invalid arguments: FEC cannot be used with null data")
if data is None: if data is None:
frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME

View File

@@ -22,37 +22,17 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from .asset import Asset
from typing import Any, Dict, Optional, TYPE_CHECKING, Type, TypeVar, Union
import re
from .asset import Asset, AssetMixin
from .errors import InvalidArgument
from . import utils from . import utils
__all__ = ( __all__ = (
'PartialEmoji', 'PartialEmoji',
) )
if TYPE_CHECKING:
from .state import ConnectionState
from datetime import datetime
from .types.message import PartialEmoji as PartialEmojiPayload
class _EmojiTag: class _EmojiTag:
__slots__ = () __slots__ = ()
id: int class PartialEmoji(_EmojiTag):
def _to_partial(self) -> PartialEmoji:
raise NotImplementedError
PE = TypeVar('PE', bound='PartialEmoji')
class PartialEmoji(_EmojiTag, AssetMixin):
"""Represents a "partial" emoji. """Represents a "partial" emoji.
This model will be given in two scenarios: This model will be given in two scenarios:
@@ -92,80 +72,35 @@ class PartialEmoji(_EmojiTag, AssetMixin):
__slots__ = ('animated', 'name', 'id', '_state') __slots__ = ('animated', 'name', 'id', '_state')
_CUSTOM_EMOJI_RE = re.compile(r'<?(?P<animated>a)?:?(?P<name>[A-Za-z0-9\_]+):(?P<id>[0-9]{13,20})>?') def __init__(self, *, name, animated=False, id=None):
if TYPE_CHECKING:
id: Optional[int]
def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = None):
self.animated = animated self.animated = animated
self.name = name self.name = name
self.id = id self.id = id
self._state: Optional[ConnectionState] = None self._state = None
@classmethod @classmethod
def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE: def from_dict(cls, data):
return cls( return cls(
animated=data.get('animated', False), animated=data.get('animated', False),
id=utils._get_as_snowflake(data, 'id'), id=utils._get_as_snowflake(data, 'id'),
name=data.get('name') or '', name=data.get('name'),
) )
@classmethod def to_dict(self):
def from_str(cls: Type[PE], value: str) -> PE: o = { 'name': self.name }
"""Converts a Discord string representation of an emoji to a :class:`PartialEmoji`.
The formats accepted are:
- ``a:name:id``
- ``<a:name:id>``
- ``name:id``
- ``<:name:id>``
If the format does not match then it is assumed to be a unicode emoji.
.. versionadded:: 2.0
Parameters
------------
value: :class:`str`
The string representation of an emoji.
Returns
--------
:class:`PartialEmoji`
The partial emoji from this string.
"""
match = cls._CUSTOM_EMOJI_RE.match(value)
if match is not None:
groups = match.groupdict()
animated = bool(groups['animated'])
emoji_id = int(groups['id'])
name = groups['name']
return cls(name=name, animated=animated, id=emoji_id)
return cls(name=value, id=None, animated=False)
def to_dict(self) -> Dict[str, Any]:
o: Dict[str, Any] = {'name': self.name}
if self.id: if self.id:
o['id'] = self.id o['id'] = self.id
if self.animated: if self.animated:
o['animated'] = self.animated o['animated'] = self.animated
return o return o
def _to_partial(self) -> PartialEmoji:
return self
@classmethod @classmethod
def with_state( def with_state(cls, state, *, name, animated=False, id=None):
cls: Type[PE], state: ConnectionState, *, name: str, animated: bool = False, id: Optional[int] = None
) -> PE:
self = cls(name=name, animated=animated, id=id) self = cls(name=name, animated=animated, id=id)
self._state = state self._state = state
return self return self
def __str__(self) -> str: def __str__(self):
if self.id is None: if self.id is None:
return self.name return self.name
if self.animated: if self.animated:
@@ -173,9 +108,9 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return f'<:{self.name}:{self.id}>' return f'<:{self.name}:{self.id}>'
def __repr__(self): def __repr__(self):
return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' return '<{0.__class__.__name__} animated={0.animated} name={0.name!r} id={0.id}>'.format(self)
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
if self.is_unicode_emoji(): if self.is_unicode_emoji():
return isinstance(other, PartialEmoji) and self.name == other.name return isinstance(other, PartialEmoji) and self.name == other.name
@@ -183,50 +118,75 @@ class PartialEmoji(_EmojiTag, AssetMixin):
return self.id == other.id return self.id == other.id
return False return False
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return hash((self.id, self.name)) return hash((self.id, self.name))
def is_custom_emoji(self) -> bool: def is_custom_emoji(self):
""":class:`bool`: Checks if this is a custom non-Unicode emoji.""" """:class:`bool`: Checks if this is a custom non-Unicode emoji."""
return self.id is not None return self.id is not None
def is_unicode_emoji(self) -> bool: def is_unicode_emoji(self):
""":class:`bool`: Checks if this is a Unicode emoji.""" """:class:`bool`: Checks if this is a Unicode emoji."""
return self.id is None return self.id is None
def _as_reaction(self) -> str: def _as_reaction(self):
if self.id is None: if self.id is None:
return self.name return self.name
return f'{self.name}:{self.id}' return f'{self.name}:{self.id}'
@property @property
def created_at(self) -> Optional[datetime]: def created_at(self):
"""Optional[:class:`datetime.datetime`]: Returns the emoji's creation time in UTC, or None if Unicode emoji. """Optional[:class:`datetime.datetime`]: Returns the emoji's creation time in UTC, or None if Unicode emoji.
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
if self.id is None: if self.is_unicode_emoji():
return None return None
return utils.snowflake_time(self.id) return utils.snowflake_time(self.id)
@property @property
def url(self) -> str: def url(self):
""":class:`str`: Returns the URL of the emoji, if it is custom. """:class:`Asset`: Returns the asset of the emoji, if it is custom.
If this isn't a custom emoji then an empty string is returned This is equivalent to calling :meth:`url_as` with
the default parameters (i.e. png/gif detection).
"""
return self.url_as(format=None)
def url_as(self, *, format=None, static_format="png"):
"""Returns an :class:`Asset` for the emoji's url, if it is custom.
The format must be one of 'webp', 'jpeg', 'jpg', 'png' or 'gif'.
'gif' is only valid for animated emojis.
.. versionadded:: 1.7
Parameters
-----------
format: Optional[:class:`str`]
The format to attempt to convert the emojis to.
If the format is ``None``, then it is automatically
detected as either 'gif' or static_format, depending on whether the
emoji is animated or not.
static_format: Optional[:class:`str`]
Format to attempt to convert only non-animated emoji's to.
Defaults to 'png'
Raises
-------
InvalidArgument
Bad image format passed to ``format`` or ``static_format``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
""" """
if self.is_unicode_emoji(): if self.is_unicode_emoji():
return '' return Asset(self._state)
fmt = 'gif' if self.animated else 'png' return Asset._from_emoji(self._state, self, format=format, static_format=static_format)
return f'{Asset.BASE}/emojis/{self.id}.{fmt}'
async def read(self) -> bytes:
if self.is_unicode_emoji():
raise InvalidArgument('PartialEmoji is not a custom emoji')
return await super().read()

View File

@@ -22,9 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, Tuple, Type, TypeVar, Optional
from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value
__all__ = ( __all__ = (
@@ -35,19 +32,15 @@ __all__ = (
# A permission alias works like a regular flag but is marked # A permission alias works like a regular flag but is marked
# So the PermissionOverwrite knows to work with it # So the PermissionOverwrite knows to work with it
class permission_alias(alias_flag_value): class permission_alias(alias_flag_value):
alias: str pass
def make_permission_alias(alias):
def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permission_alias]: def decorator(func):
def decorator(func: Callable[[Any], int]) -> permission_alias:
ret = permission_alias(func) ret = permission_alias(func)
ret.alias = alias ret.alias = alias
return ret return ret
return decorator return decorator
P = TypeVar('P', bound='Permissions')
@fill_with_flags() @fill_with_flags()
class Permissions(BaseFlags): class Permissions(BaseFlags):
"""Wraps up the Discord permission value. """Wraps up the Discord permission value.
@@ -99,7 +92,7 @@ class Permissions(BaseFlags):
__slots__ = () __slots__ = ()
def __init__(self, permissions: int = 0, **kwargs: bool): def __init__(self, permissions=0, **kwargs):
if not isinstance(permissions, int): if not isinstance(permissions, int):
raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.') raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.')
@@ -109,25 +102,25 @@ class Permissions(BaseFlags):
raise TypeError(f'{key!r} is not a valid permission name.') raise TypeError(f'{key!r} is not a valid permission name.')
setattr(self, key, value) setattr(self, key, value)
def is_subset(self, other: Permissions) -> bool: def is_subset(self, other):
"""Returns ``True`` if self has the same or fewer permissions as other.""" """Returns ``True`` if self has the same or fewer permissions as other."""
if isinstance(other, Permissions): if isinstance(other, Permissions):
return (self.value & other.value) == self.value return (self.value & other.value) == self.value
else: else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
def is_superset(self, other: Permissions) -> bool: def is_superset(self, other):
"""Returns ``True`` if self has the same or more permissions as other.""" """Returns ``True`` if self has the same or more permissions as other."""
if isinstance(other, Permissions): if isinstance(other, Permissions):
return (self.value | other.value) == self.value return (self.value | other.value) == self.value
else: else:
raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}")
def is_strict_subset(self, other: Permissions) -> bool: def is_strict_subset(self, other):
"""Returns ``True`` if the permissions on other are a strict subset of those on self.""" """Returns ``True`` if the permissions on other are a strict subset of those on self."""
return self.is_subset(other) and self != other return self.is_subset(other) and self != other
def is_strict_superset(self, other: Permissions) -> bool: def is_strict_superset(self, other):
"""Returns ``True`` if the permissions on other are a strict superset of those on self.""" """Returns ``True`` if the permissions on other are a strict superset of those on self."""
return self.is_superset(other) and self != other return self.is_superset(other) and self != other
@@ -137,20 +130,20 @@ class Permissions(BaseFlags):
__gt__ = is_strict_superset __gt__ = is_strict_superset
@classmethod @classmethod
def none(cls: Type[P]) -> P: def none(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
permissions set to ``False``.""" permissions set to ``False``."""
return cls(0) return cls(0)
@classmethod @classmethod
def all(cls: Type[P]) -> P: def all(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
permissions set to ``True``. permissions set to ``True``.
""" """
return cls(0b111111111111111111111111111111111111111) return cls(0b111111111111111111111111111111111)
@classmethod @classmethod
def all_channel(cls: Type[P]) -> P: def all_channel(cls):
"""A :class:`Permissions` with all channel-specific permissions set to """A :class:`Permissions` with all channel-specific permissions set to
``True`` and the guild-specific ones set to ``False``. The guild-specific ``True`` and the guild-specific ones set to ``False``. The guild-specific
permissions are currently: permissions are currently:
@@ -167,16 +160,11 @@ class Permissions(BaseFlags):
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions. Added :attr:`stream`, :attr:`priority_speaker` and :attr:`use_slash_commands` permissions.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`use_external_stickers`, :attr:`send_messages_in_threads` and
:attr:`request_to_speak` permissions.
""" """
return cls(0b111110110110011111101111111111101010001) return cls(0b10110011111101111111111101010001)
@classmethod @classmethod
def general(cls: Type[P]) -> P: def general(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"General" permissions from the official Discord UI set to ``True``. "General" permissions from the official Discord UI set to ``True``.
@@ -189,7 +177,7 @@ class Permissions(BaseFlags):
return cls(0b01110000000010000000010010110000) return cls(0b01110000000010000000010010110000)
@classmethod @classmethod
def membership(cls: Type[P]) -> P: def membership(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Membership" permissions from the official Discord UI set to ``True``. "Membership" permissions from the official Discord UI set to ``True``.
@@ -198,28 +186,24 @@ class Permissions(BaseFlags):
return cls(0b00001100000000000000000000000111) return cls(0b00001100000000000000000000000111)
@classmethod @classmethod
def text(cls: Type[P]) -> P: def text(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Text" permissions from the official Discord UI set to ``True``. "Text" permissions from the official Discord UI set to ``True``.
.. versionchanged:: 1.7 .. versionchanged:: 1.7
Permission :attr:`read_messages` is no longer part of the text permissions. Permission :attr:`read_messages` is no longer part of the text permissions.
Added :attr:`use_slash_commands` permission. Added :attr:`use_slash_commands` permission.
.. versionchanged:: 2.0
Added :attr:`create_public_threads`, :attr:`create_private_threads`, :attr:`manage_threads`,
:attr:`send_messages_in_threads` and :attr:`use_external_stickers` permissions.
""" """
return cls(0b111110010000000000001111111100001000000) return cls(0b10000000000001111111100001000000)
@classmethod @classmethod
def voice(cls: Type[P]) -> P: def voice(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Voice" permissions from the official Discord UI set to ``True``.""" "Voice" permissions from the official Discord UI set to ``True``."""
return cls(0b00000011111100000000001100000000) return cls(0b00000011111100000000001100000000)
@classmethod @classmethod
def stage(cls: Type[P]) -> P: def stage(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Stage Channel" permissions from the official Discord UI set to ``True``. "Stage Channel" permissions from the official Discord UI set to ``True``.
@@ -228,7 +212,7 @@ class Permissions(BaseFlags):
return cls(1 << 32) return cls(1 << 32)
@classmethod @classmethod
def stage_moderator(cls: Type[P]) -> P: def stage_moderator(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Stage Moderator" permissions from the official Discord UI set to ``True``. "Stage Moderator" permissions from the official Discord UI set to ``True``.
@@ -237,7 +221,7 @@ class Permissions(BaseFlags):
return cls(0b100000001010000000000000000000000) return cls(0b100000001010000000000000000000000)
@classmethod @classmethod
def advanced(cls: Type[P]) -> P: def advanced(cls):
"""A factory method that creates a :class:`Permissions` with all """A factory method that creates a :class:`Permissions` with all
"Advanced" permissions from the official Discord UI set to ``True``. "Advanced" permissions from the official Discord UI set to ``True``.
@@ -245,7 +229,7 @@ class Permissions(BaseFlags):
""" """
return cls(1 << 3) return cls(1 << 3)
def update(self, **kwargs: bool) -> None: def update(self, **kwargs):
r"""Bulk updates this permission object. r"""Bulk updates this permission object.
Allows you to set multiple attributes by using keyword Allows you to set multiple attributes by using keyword
@@ -261,7 +245,7 @@ class Permissions(BaseFlags):
if key in self.VALID_FLAGS: if key in self.VALID_FLAGS:
setattr(self, key, value) setattr(self, key, value)
def handle_overwrite(self, allow: int, deny: int) -> None: def handle_overwrite(self, allow, deny):
# Basically this is what's happening here. # Basically this is what's happening here.
# We have an original bit array, e.g. 1010 # We have an original bit array, e.g. 1010
# Then we have another bit array that is 'denied', e.g. 1111 # Then we have another bit array that is 'denied', e.g. 1111
@@ -277,67 +261,69 @@ class Permissions(BaseFlags):
self.value = (self.value & ~deny) | allow self.value = (self.value & ~deny) | allow
@flag_value @flag_value
def create_instant_invite(self) -> int: def create_instant_invite(self):
""":class:`bool`: Returns ``True`` if the user can create instant invites.""" """:class:`bool`: Returns ``True`` if the user can create instant invites."""
return 1 << 0 return 1 << 0
@flag_value @flag_value
def kick_members(self) -> int: def kick_members(self):
""":class:`bool`: Returns ``True`` if the user can kick users from the guild.""" """:class:`bool`: Returns ``True`` if the user can kick users from the guild."""
return 1 << 1 return 1 << 1
@flag_value @flag_value
def ban_members(self) -> int: def ban_members(self):
""":class:`bool`: Returns ``True`` if a user can ban users from the guild.""" """:class:`bool`: Returns ``True`` if a user can ban users from the guild."""
return 1 << 2 return 1 << 2
@flag_value @flag_value
def administrator(self) -> int: def administrator(self):
""":class:`bool`: Returns ``True`` if a user is an administrator. This role overrides all other permissions. """:class:`bool`: Returns ``True`` if a user is an administrator. This role overrides all other permissions.
This also bypasses all channel-specific overrides. This also bypasses all channel-specific overrides.
""" """
return 1 << 3 return 1 << 3
admin = administrator
@flag_value @flag_value
def manage_channels(self) -> int: def manage_channels(self):
""":class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild. """:class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild.
This also corresponds to the "Manage Channel" channel-specific override.""" This also corresponds to the "Manage Channel" channel-specific override."""
return 1 << 4 return 1 << 4
@flag_value @flag_value
def manage_guild(self) -> int: def manage_guild(self):
""":class:`bool`: Returns ``True`` if a user can edit guild properties.""" """:class:`bool`: Returns ``True`` if a user can edit guild properties."""
return 1 << 5 return 1 << 5
@flag_value @flag_value
def add_reactions(self) -> int: def add_reactions(self):
""":class:`bool`: Returns ``True`` if a user can add reactions to messages.""" """:class:`bool`: Returns ``True`` if a user can add reactions to messages."""
return 1 << 6 return 1 << 6
@flag_value @flag_value
def view_audit_log(self) -> int: def view_audit_log(self):
""":class:`bool`: Returns ``True`` if a user can view the guild's audit log.""" """:class:`bool`: Returns ``True`` if a user can view the guild's audit log."""
return 1 << 7 return 1 << 7
@flag_value @flag_value
def priority_speaker(self) -> int: def priority_speaker(self):
""":class:`bool`: Returns ``True`` if a user can be more easily heard while talking.""" """:class:`bool`: Returns ``True`` if a user can be more easily heard while talking."""
return 1 << 8 return 1 << 8
@flag_value @flag_value
def stream(self) -> int: def stream(self):
""":class:`bool`: Returns ``True`` if a user can stream in a voice channel.""" """:class:`bool`: Returns ``True`` if a user can stream in a voice channel."""
return 1 << 9 return 1 << 9
@flag_value @flag_value
def read_messages(self) -> int: def read_messages(self):
""":class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels.""" """:class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels."""
return 1 << 10 return 1 << 10
@make_permission_alias('read_messages') @make_permission_alias('read_messages')
def view_channel(self) -> int: def view_channel(self):
""":class:`bool`: An alias for :attr:`read_messages`. """:class:`bool`: An alias for :attr:`read_messages`.
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -345,17 +331,17 @@ class Permissions(BaseFlags):
return 1 << 10 return 1 << 10
@flag_value @flag_value
def send_messages(self) -> int: def send_messages(self):
""":class:`bool`: Returns ``True`` if a user can send messages from all or specific text channels.""" """:class:`bool`: Returns ``True`` if a user can send messages from all or specific text channels."""
return 1 << 11 return 1 << 11
@flag_value @flag_value
def send_tts_messages(self) -> int: def send_tts_messages(self):
""":class:`bool`: Returns ``True`` if a user can send TTS messages from all or specific text channels.""" """:class:`bool`: Returns ``True`` if a user can send TTS messages from all or specific text channels."""
return 1 << 12 return 1 << 12
@flag_value @flag_value
def manage_messages(self) -> int: def manage_messages(self):
""":class:`bool`: Returns ``True`` if a user can delete or pin messages in a text channel. """:class:`bool`: Returns ``True`` if a user can delete or pin messages in a text channel.
.. note:: .. note::
@@ -365,32 +351,32 @@ class Permissions(BaseFlags):
return 1 << 13 return 1 << 13
@flag_value @flag_value
def embed_links(self) -> int: def embed_links(self):
""":class:`bool`: Returns ``True`` if a user's messages will automatically be embedded by Discord.""" """:class:`bool`: Returns ``True`` if a user's messages will automatically be embedded by Discord."""
return 1 << 14 return 1 << 14
@flag_value @flag_value
def attach_files(self) -> int: def attach_files(self):
""":class:`bool`: Returns ``True`` if a user can send files in their messages.""" """:class:`bool`: Returns ``True`` if a user can send files in their messages."""
return 1 << 15 return 1 << 15
@flag_value @flag_value
def read_message_history(self) -> int: def read_message_history(self):
""":class:`bool`: Returns ``True`` if a user can read a text channel's previous messages.""" """:class:`bool`: Returns ``True`` if a user can read a text channel's previous messages."""
return 1 << 16 return 1 << 16
@flag_value @flag_value
def mention_everyone(self) -> int: def mention_everyone(self):
""":class:`bool`: Returns ``True`` if a user's @everyone or @here will mention everyone in the text channel.""" """:class:`bool`: Returns ``True`` if a user's @everyone or @here will mention everyone in the text channel."""
return 1 << 17 return 1 << 17
@flag_value @flag_value
def external_emojis(self) -> int: def external_emojis(self):
""":class:`bool`: Returns ``True`` if a user can use emojis from other guilds.""" """:class:`bool`: Returns ``True`` if a user can use emojis from other guilds."""
return 1 << 18 return 1 << 18
@make_permission_alias('external_emojis') @make_permission_alias('external_emojis')
def use_external_emojis(self) -> int: def use_external_emojis(self):
""":class:`bool`: An alias for :attr:`external_emojis`. """:class:`bool`: An alias for :attr:`external_emojis`.
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -398,7 +384,7 @@ class Permissions(BaseFlags):
return 1 << 18 return 1 << 18
@flag_value @flag_value
def view_guild_insights(self) -> int: def view_guild_insights(self):
""":class:`bool`: Returns ``True`` if a user can view the guild's insights. """:class:`bool`: Returns ``True`` if a user can view the guild's insights.
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -406,47 +392,47 @@ class Permissions(BaseFlags):
return 1 << 19 return 1 << 19
@flag_value @flag_value
def connect(self) -> int: def connect(self):
""":class:`bool`: Returns ``True`` if a user can connect to a voice channel.""" """:class:`bool`: Returns ``True`` if a user can connect to a voice channel."""
return 1 << 20 return 1 << 20
@flag_value @flag_value
def speak(self) -> int: def speak(self):
""":class:`bool`: Returns ``True`` if a user can speak in a voice channel.""" """:class:`bool`: Returns ``True`` if a user can speak in a voice channel."""
return 1 << 21 return 1 << 21
@flag_value @flag_value
def mute_members(self) -> int: def mute_members(self):
""":class:`bool`: Returns ``True`` if a user can mute other users.""" """:class:`bool`: Returns ``True`` if a user can mute other users."""
return 1 << 22 return 1 << 22
@flag_value @flag_value
def deafen_members(self) -> int: def deafen_members(self):
""":class:`bool`: Returns ``True`` if a user can deafen other users.""" """:class:`bool`: Returns ``True`` if a user can deafen other users."""
return 1 << 23 return 1 << 23
@flag_value @flag_value
def move_members(self) -> int: def move_members(self):
""":class:`bool`: Returns ``True`` if a user can move users between other voice channels.""" """:class:`bool`: Returns ``True`` if a user can move users between other voice channels."""
return 1 << 24 return 1 << 24
@flag_value @flag_value
def use_voice_activation(self) -> int: def use_voice_activation(self):
""":class:`bool`: Returns ``True`` if a user can use voice activation in voice channels.""" """:class:`bool`: Returns ``True`` if a user can use voice activation in voice channels."""
return 1 << 25 return 1 << 25
@flag_value @flag_value
def change_nickname(self) -> int: def change_nickname(self):
""":class:`bool`: Returns ``True`` if a user can change their nickname in the guild.""" """:class:`bool`: Returns ``True`` if a user can change their nickname in the guild."""
return 1 << 26 return 1 << 26
@flag_value @flag_value
def manage_nicknames(self) -> int: def manage_nicknames(self):
""":class:`bool`: Returns ``True`` if a user can change other user's nickname in the guild.""" """:class:`bool`: Returns ``True`` if a user can change other user's nickname in the guild."""
return 1 << 27 return 1 << 27
@flag_value @flag_value
def manage_roles(self) -> int: def manage_roles(self):
""":class:`bool`: Returns ``True`` if a user can create or edit roles less than their role's position. """:class:`bool`: Returns ``True`` if a user can create or edit roles less than their role's position.
This also corresponds to the "Manage Permissions" channel-specific override. This also corresponds to the "Manage Permissions" channel-specific override.
@@ -454,7 +440,7 @@ class Permissions(BaseFlags):
return 1 << 28 return 1 << 28
@make_permission_alias('manage_roles') @make_permission_alias('manage_roles')
def manage_permissions(self) -> int: def manage_permissions(self):
""":class:`bool`: An alias for :attr:`manage_roles`. """:class:`bool`: An alias for :attr:`manage_roles`.
.. versionadded:: 1.3 .. versionadded:: 1.3
@@ -462,25 +448,17 @@ class Permissions(BaseFlags):
return 1 << 28 return 1 << 28
@flag_value @flag_value
def manage_webhooks(self) -> int: def manage_webhooks(self):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete webhooks.""" """:class:`bool`: Returns ``True`` if a user can create, edit, or delete webhooks."""
return 1 << 29 return 1 << 29
@flag_value @flag_value
def manage_emojis(self) -> int: def manage_emojis(self):
""":class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis.""" """:class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis."""
return 1 << 30 return 1 << 30
@make_permission_alias('manage_emojis')
def manage_emojis_and_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`manage_emojis`.
.. versionadded:: 2.0
"""
return 1 << 30
@flag_value @flag_value
def use_slash_commands(self) -> int: def use_slash_commands(self):
""":class:`bool`: Returns ``True`` if a user can use slash commands. """:class:`bool`: Returns ``True`` if a user can use slash commands.
.. versionadded:: 1.7 .. versionadded:: 1.7
@@ -488,72 +466,14 @@ class Permissions(BaseFlags):
return 1 << 31 return 1 << 31
@flag_value @flag_value
def request_to_speak(self) -> int: def request_to_speak(self):
""":class:`bool`: Returns ``True`` if a user can request to speak in a stage channel. """:class:`bool`: Returns ``True`` if a user can request to speak in a stage channel.
.. versionadded:: 1.7 .. versionadded:: 1.7
""" """
return 1 << 32 return 1 << 32
@flag_value def augment_from_permissions(cls):
def manage_events(self) -> int:
""":class:`bool`: Returns ``True`` if a user can manage guild events.
.. versionadded:: 2.0
"""
return 1 << 33
@flag_value
def manage_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can manage threads.
.. versionadded:: 2.0
"""
return 1 << 34
@flag_value
def create_public_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create public threads.
.. versionadded:: 2.0
"""
return 1 << 35
@flag_value
def create_private_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can create private threads.
.. versionadded:: 2.0
"""
return 1 << 36
@flag_value
def external_stickers(self) -> int:
""":class:`bool`: Returns ``True`` if a user can use stickers from other guilds.
.. versionadded:: 2.0
"""
return 1 << 37
@make_permission_alias('external_stickers')
def use_external_stickers(self) -> int:
""":class:`bool`: An alias for :attr:`external_stickers`.
.. versionadded:: 2.0
"""
return 1 << 37
@flag_value
def send_messages_in_threads(self) -> int:
""":class:`bool`: Returns ``True`` if a user can send messages in threads.
.. versionadded:: 2.0
"""
return 1 << 38
PO = TypeVar('PO', bound='PermissionOverwrite')
def _augment_from_permissions(cls):
cls.VALID_NAMES = set(Permissions.VALID_FLAGS) cls.VALID_NAMES = set(Permissions.VALID_FLAGS)
aliases = set() aliases = set()
@@ -570,7 +490,6 @@ def _augment_from_permissions(cls):
# god bless Python # god bless Python
def getter(self, x=key): def getter(self, x=key):
return self._values.get(x) return self._values.get(x)
def setter(self, value, x=key): def setter(self, value, x=key):
self._set(x, value) self._set(x, value)
@@ -580,8 +499,7 @@ def _augment_from_permissions(cls):
cls.PURE_FLAGS = cls.VALID_NAMES - aliases cls.PURE_FLAGS = cls.VALID_NAMES - aliases
return cls return cls
@augment_from_permissions
@_augment_from_permissions
class PermissionOverwrite: class PermissionOverwrite:
r"""A type that is used to represent a channel specific permission. r"""A type that is used to represent a channel specific permission.
@@ -616,57 +534,8 @@ class PermissionOverwrite:
__slots__ = ('_values',) __slots__ = ('_values',)
if TYPE_CHECKING: def __init__(self, **kwargs):
VALID_NAMES: ClassVar[Set[str]] self._values = {}
PURE_FLAGS: ClassVar[Set[str]]
# I wish I didn't have to do this
create_instant_invite: Optional[bool]
kick_members: Optional[bool]
ban_members: Optional[bool]
administrator: Optional[bool]
manage_channels: Optional[bool]
manage_guild: Optional[bool]
add_reactions: Optional[bool]
view_audit_log: Optional[bool]
priority_speaker: Optional[bool]
stream: Optional[bool]
read_messages: Optional[bool]
view_channel: Optional[bool]
send_messages: Optional[bool]
send_tts_messages: Optional[bool]
manage_messages: Optional[bool]
embed_links: Optional[bool]
attach_files: Optional[bool]
read_message_history: Optional[bool]
mention_everyone: Optional[bool]
external_emojis: Optional[bool]
use_external_emojis: Optional[bool]
view_guild_insights: Optional[bool]
connect: Optional[bool]
speak: Optional[bool]
mute_members: Optional[bool]
deafen_members: Optional[bool]
move_members: Optional[bool]
use_voice_activation: Optional[bool]
change_nickname: Optional[bool]
manage_nicknames: Optional[bool]
manage_roles: Optional[bool]
manage_permissions: Optional[bool]
manage_webhooks: Optional[bool]
manage_emojis: Optional[bool]
manage_emojis_and_stickers: Optional[bool]
use_slash_commands: Optional[bool]
request_to_speak: Optional[bool]
manage_events: Optional[bool]
manage_threads: Optional[bool]
create_public_threads: Optional[bool]
create_private_threads: Optional[bool]
send_messages_in_threads: Optional[bool]
external_stickers: Optional[bool]
use_external_stickers: Optional[bool]
def __init__(self, **kwargs: Optional[bool]):
self._values: Dict[str, Optional[bool]] = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if key not in self.VALID_NAMES: if key not in self.VALID_NAMES:
@@ -674,10 +543,10 @@ class PermissionOverwrite:
setattr(self, key, value) setattr(self, key, value)
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, PermissionOverwrite) and self._values == other._values return isinstance(other, PermissionOverwrite) and self._values == other._values
def _set(self, key: str, value: Optional[bool]) -> None: def _set(self, key, value):
if value not in (True, None, False): if value not in (True, None, False):
raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}') raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}')
@@ -686,7 +555,7 @@ class PermissionOverwrite:
else: else:
self._values[key] = value self._values[key] = value
def pair(self) -> Tuple[Permissions, Permissions]: def pair(self):
"""Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite.""" """Tuple[:class:`Permissions`, :class:`Permissions`]: Returns the (allow, deny) pair from this overwrite."""
allow = Permissions.none() allow = Permissions.none()
@@ -701,7 +570,7 @@ class PermissionOverwrite:
return allow, deny return allow, deny
@classmethod @classmethod
def from_pair(cls: Type[PO], allow: Permissions, deny: Permissions) -> PO: def from_pair(cls, allow, deny):
"""Creates an overwrite from an allow/deny pair of :class:`Permissions`.""" """Creates an overwrite from an allow/deny pair of :class:`Permissions`."""
ret = cls() ret = cls()
for key, value in allow: for key, value in allow:
@@ -714,7 +583,7 @@ class PermissionOverwrite:
return ret return ret
def is_empty(self) -> bool: def is_empty(self):
"""Checks if the permission overwrite is currently empty. """Checks if the permission overwrite is currently empty.
An empty permission overwrite is one that has no overwrites set An empty permission overwrite is one that has no overwrites set
@@ -727,7 +596,7 @@ class PermissionOverwrite:
""" """
return len(self._values) == 0 return len(self._values) == 0
def update(self, **kwargs: bool) -> None: def update(self, **kwargs):
r"""Bulk updates this permission overwrite object. r"""Bulk updates this permission overwrite object.
Allows you to set multiple attributes by using keyword Allows you to set multiple attributes by using keyword
@@ -745,6 +614,6 @@ class PermissionOverwrite:
setattr(self, key, value) setattr(self, key, value)
def __iter__(self) -> Iterator[Tuple[str, Optional[bool]]]: def __iter__(self):
for key in self.PURE_FLAGS: for key in self.PURE_FLAGS:
yield key, self._values.get(key) yield key, self._values.get(key)

View File

@@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import threading import threading
import traceback import traceback
@@ -34,23 +33,12 @@ import time
import json import json
import sys import sys
import re import re
import io
from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
from .errors import ClientException from .errors import ClientException
from .opus import Encoder as OpusEncoder from .opus import Encoder as OpusEncoder
from .oggparse import OggStream from .oggparse import OggStream
from .utils import MISSING
if TYPE_CHECKING: log = logging.getLogger(__name__)
from .voice_client import VoiceClient
AT = TypeVar('AT', bound='AudioSource')
FT = TypeVar('FT', bound='FFmpegOpusAudio')
_log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'AudioSource', 'AudioSource',
@@ -61,8 +49,6 @@ __all__ = (
'PCMVolumeTransformer', 'PCMVolumeTransformer',
) )
CREATE_NO_WINDOW: int
if sys.platform != 'win32': if sys.platform != 'win32':
CREATE_NO_WINDOW = 0 CREATE_NO_WINDOW = 0
else: else:
@@ -79,7 +65,7 @@ class AudioSource:
The audio source reads are done in a separate thread. The audio source reads are done in a separate thread.
""" """
def read(self) -> bytes: def read(self):
"""Reads 20ms worth of audio. """Reads 20ms worth of audio.
Subclasses must implement this. Subclasses must implement this.
@@ -87,7 +73,7 @@ class AudioSource:
If the audio is complete, then returning an empty If the audio is complete, then returning an empty
:term:`py:bytes-like object` to signal this is the way to do so. :term:`py:bytes-like object` to signal this is the way to do so.
If :meth:`~AudioSource.is_opus` method returns ``True``, then it must return If :meth:`is_opus` method returns ``True``, then it must return
20ms worth of Opus encoded audio. Otherwise, it must be 20ms 20ms worth of Opus encoded audio. Otherwise, it must be 20ms
worth of 16-bit 48KHz stereo PCM, which is about 3,840 bytes worth of 16-bit 48KHz stereo PCM, which is about 3,840 bytes
per frame (20ms worth of audio). per frame (20ms worth of audio).
@@ -99,11 +85,11 @@ class AudioSource:
""" """
raise NotImplementedError raise NotImplementedError
def is_opus(self) -> bool: def is_opus(self):
"""Checks if the audio source is already encoded in Opus.""" """Checks if the audio source is already encoded in Opus."""
return False return False
def cleanup(self) -> None: def cleanup(self):
"""Called when clean-up is needed to be done. """Called when clean-up is needed to be done.
Useful for clearing buffer data or processes after Useful for clearing buffer data or processes after
@@ -111,7 +97,7 @@ class AudioSource:
""" """
pass pass
def __del__(self) -> None: def __del__(self):
self.cleanup() self.cleanup()
class PCMAudio(AudioSource): class PCMAudio(AudioSource):
@@ -122,10 +108,10 @@ class PCMAudio(AudioSource):
stream: :term:`py:file object` stream: :term:`py:file object`
A file-like object that reads byte data representing raw PCM. A file-like object that reads byte data representing raw PCM.
""" """
def __init__(self, stream: io.BufferedIOBase) -> None: def __init__(self, stream):
self.stream: io.BufferedIOBase = stream self.stream = stream
def read(self) -> bytes: def read(self):
ret = self.stream.read(OpusEncoder.FRAME_SIZE) ret = self.stream.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE: if len(ret) != OpusEncoder.FRAME_SIZE:
return b'' return b''
@@ -140,27 +126,17 @@ class FFmpegAudio(AudioSource):
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): def __init__(self, source, *, executable='ffmpeg', args, **subprocess_kwargs):
piping = subprocess_kwargs.get('stdin') == subprocess.PIPE self._process = self._stdout = None
if piping and isinstance(source, str):
raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin")
args = [executable, *args] args = [executable, *args]
kwargs = {'stdout': subprocess.PIPE} kwargs = {'stdout': subprocess.PIPE}
kwargs.update(subprocess_kwargs) kwargs.update(subprocess_kwargs)
self._process: subprocess.Popen = self._spawn_process(args, **kwargs) self._process = self._spawn_process(args, **kwargs)
self._stdout: IO[bytes] = self._process.stdout # type: ignore self._stdout = self._process.stdout
self._stdin: Optional[IO[Bytes]] = None
self._pipe_thread: Optional[threading.Thread] = None
if piping: def _spawn_process(self, args, **subprocess_kwargs):
n = f'popen-stdin-writer:{id(self):#x}'
self._stdin = self._process.stdin
self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n)
self._pipe_thread.start()
def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen:
process = None process = None
try: try:
process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs)
@@ -168,48 +144,30 @@ class FFmpegAudio(AudioSource):
executable = args.partition(' ')[0] if isinstance(args, str) else args[0] executable = args.partition(' ')[0] if isinstance(args, str) else args[0]
raise ClientException(executable + ' was not found.') from None raise ClientException(executable + ' was not found.') from None
except subprocess.SubprocessError as exc: except subprocess.SubprocessError as exc:
raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc raise ClientException('Popen failed: {0.__class__.__name__}: {0}'.format(exc)) from exc
else: else:
return process return process
def _kill_process(self) -> None: def cleanup(self):
proc = self._process proc = self._process
if proc is MISSING: if proc is None:
return return
_log.info('Preparing to terminate ffmpeg process %s.', proc.pid) log.info('Preparing to terminate ffmpeg process %s.', proc.pid)
try: try:
proc.kill() proc.kill()
except Exception: except Exception:
_log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid) log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid)
if proc.poll() is None: if proc.poll() is None:
_log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid) log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid)
proc.communicate() proc.communicate()
_log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode) log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode)
else: else:
_log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode)
self._process = self._stdout = None
def _pipe_writer(self, source: io.BufferedIOBase) -> None:
while self._process:
# arbitrarily large read size
data = source.read(8192)
if not data:
self._process.terminate()
return
try:
self._stdin.write(data)
except Exception:
_log.debug('Write error for %s, this is probably not a problem', self, exc_info=True)
# at this point the source data is either exhausted or the process is fubar
self._process.terminate()
return
def cleanup(self) -> None:
self._kill_process()
self._process = self._stdout = self._stdin = MISSING
class FFmpegPCMAudio(FFmpegAudio): class FFmpegPCMAudio(FFmpegAudio):
"""An audio source from FFmpeg (or AVConv). """An audio source from FFmpeg (or AVConv).
@@ -246,18 +204,9 @@ class FFmpegPCMAudio(FFmpegAudio):
The subprocess failed to be created. The subprocess failed to be created.
""" """
def __init__( def __init__(self, source, *, executable='ffmpeg', pipe=False, stderr=None, before_options=None, options=None):
self,
source: Union[str, io.BufferedIOBase],
*,
executable: str = 'ffmpeg',
pipe: bool = False,
stderr: Optional[IO[str]] = None,
before_options: Optional[str] = None,
options: Optional[str] = None
) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str): if isinstance(before_options, str):
args.extend(shlex.split(before_options)) args.extend(shlex.split(before_options))
@@ -273,13 +222,13 @@ class FFmpegPCMAudio(FFmpegAudio):
super().__init__(source, executable=executable, args=args, **subprocess_kwargs) super().__init__(source, executable=executable, args=args, **subprocess_kwargs)
def read(self) -> bytes: def read(self):
ret = self._stdout.read(OpusEncoder.FRAME_SIZE) ret = self._stdout.read(OpusEncoder.FRAME_SIZE)
if len(ret) != OpusEncoder.FRAME_SIZE: if len(ret) != OpusEncoder.FRAME_SIZE:
return b'' return b''
return ret return ret
def is_opus(self) -> bool: def is_opus(self):
return False return False
class FFmpegOpusAudio(FFmpegAudio): class FFmpegOpusAudio(FFmpegAudio):
@@ -343,21 +292,11 @@ class FFmpegOpusAudio(FFmpegAudio):
The subprocess failed to be created. The subprocess failed to be created.
""" """
def __init__( def __init__(self, source, *, bitrate=128, codec=None, executable='ffmpeg',
self, pipe=False, stderr=None, before_options=None, options=None):
source: Union[str, io.BufferedIOBase],
*,
bitrate: int = 128,
codec: Optional[str] = None,
executable: str = 'ffmpeg',
pipe=False,
stderr=None,
before_options=None,
options=None,
) -> None:
args = [] args = []
subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} subprocess_kwargs = {'stdin': source if pipe else subprocess.DEVNULL, 'stderr': stderr}
if isinstance(before_options, str): if isinstance(before_options, str):
args.extend(shlex.split(before_options)) args.extend(shlex.split(before_options))
@@ -384,13 +323,7 @@ class FFmpegOpusAudio(FFmpegAudio):
self._packet_iter = OggStream(self._stdout).iter_packets() self._packet_iter = OggStream(self._stdout).iter_packets()
@classmethod @classmethod
async def from_probe( async def from_probe(cls, source, *, method=None, **kwargs):
cls: Type[FT],
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
**kwargs: Any,
) -> FT:
"""|coro| """|coro|
A factory method that creates a :class:`FFmpegOpusAudio` after probing A factory method that creates a :class:`FFmpegOpusAudio` after probing
@@ -414,6 +347,7 @@ class FFmpegOpusAudio(FFmpegAudio):
def custom_probe(source, executable): def custom_probe(source, executable):
# some analysis code here # some analysis code here
return codec, bitrate return codec, bitrate
source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe) source = await discord.FFmpegOpusAudio.from_probe("song.webm", method=custom_probe)
@@ -448,16 +382,10 @@ class FFmpegOpusAudio(FFmpegAudio):
executable = kwargs.get('executable') executable = kwargs.get('executable')
codec, bitrate = await cls.probe(source, method=method, executable=executable) codec, bitrate = await cls.probe(source, method=method, executable=executable)
return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore return cls(source, bitrate=bitrate, codec=codec, **kwargs)
@classmethod @classmethod
async def probe( async def probe(cls, source, *, method=None, executable=None):
cls,
source: str,
*,
method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None,
executable: Optional[str] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""|coro| """|coro|
Probes the input source for bitrate and codec information. Probes the input source for bitrate and codec information.
@@ -480,7 +408,7 @@ class FFmpegOpusAudio(FFmpegAudio):
Returns Returns
--------- ---------
Optional[Tuple[Optional[:class:`str`], Optional[:class:`int`]]] Tuple[Optional[:class:`str`], Optional[:class:`int`]]
A 2-tuple with the codec and bitrate of the input source. A 2-tuple with the codec and bitrate of the input source.
""" """
@@ -506,26 +434,26 @@ class FFmpegOpusAudio(FFmpegAudio):
codec = bitrate = None codec = bitrate = None
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: probefunc(source, executable))
except Exception: except Exception:
if not fallback: if not fallback:
_log.exception("Probe '%s' using '%s' failed", method, executable) log.exception("Probe '%s' using '%s' failed", method, executable)
return # type: ignore return
_log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable)
try: try:
codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable))
except Exception: except Exception:
_log.exception("Fallback probe using '%s' failed", executable) log.exception("Fallback probe using '%s' failed", executable)
else: else:
_log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate) log.info("Fallback probe found codec=%s, bitrate=%s", codec, bitrate)
else: else:
_log.info("Probe found codec=%s, bitrate=%s", codec, bitrate) log.info("Probe found codec=%s, bitrate=%s", codec, bitrate)
finally: finally:
return codec, bitrate return codec, bitrate
@staticmethod @staticmethod
def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_native(source, executable='ffmpeg'):
exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable
args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source]
output = subprocess.check_output(args, timeout=20) output = subprocess.check_output(args, timeout=20)
@@ -537,12 +465,12 @@ class FFmpegOpusAudio(FFmpegAudio):
codec = streamdata.get('codec_name') codec = streamdata.get('codec_name')
bitrate = int(streamdata.get('bit_rate', 0)) bitrate = int(streamdata.get('bit_rate', 0))
bitrate = max(round(bitrate/1000), 512) bitrate = max(round(bitrate/1000, 0), 512)
return codec, bitrate return codec, bitrate
@staticmethod @staticmethod
def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: def _probe_codec_fallback(source, executable='ffmpeg'):
args = [executable, '-hide_banner', '-i', source] args = [executable, '-hide_banner', '-i', source]
proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
out, _ = proc.communicate(timeout=20) out, _ = proc.communicate(timeout=20)
@@ -559,13 +487,13 @@ class FFmpegOpusAudio(FFmpegAudio):
return codec, bitrate return codec, bitrate
def read(self) -> bytes: def read(self):
return next(self._packet_iter, b'') return next(self._packet_iter, b'')
def is_opus(self) -> bool: def is_opus(self):
return True return True
class PCMVolumeTransformer(AudioSource, Generic[AT]): class PCMVolumeTransformer(AudioSource):
"""Transforms a previous :class:`AudioSource` to have volume controls. """Transforms a previous :class:`AudioSource` to have volume controls.
This does not work on audio sources that have :meth:`AudioSource.is_opus` This does not work on audio sources that have :meth:`AudioSource.is_opus`
@@ -587,53 +515,53 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]):
The audio source is opus encoded. The audio source is opus encoded.
""" """
def __init__(self, original: AT, volume: float = 1.0): def __init__(self, original, volume=1.0):
if not isinstance(original, AudioSource): if not isinstance(original, AudioSource):
raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') raise TypeError(f'expected AudioSource not {original.__class__.__name__}.')
if original.is_opus(): if original.is_opus():
raise ClientException('AudioSource must not be Opus encoded.') raise ClientException('AudioSource must not be Opus encoded.')
self.original: AT = original self.original = original
self.volume = volume self.volume = volume
@property @property
def volume(self) -> float: def volume(self):
"""Retrieves or sets the volume as a floating point percentage (e.g. ``1.0`` for 100%).""" """Retrieves or sets the volume as a floating point percentage (e.g. ``1.0`` for 100%)."""
return self._volume return self._volume
@volume.setter @volume.setter
def volume(self, value: float) -> None: def volume(self, value):
self._volume = max(value, 0.0) self._volume = max(value, 0.0)
def cleanup(self) -> None: def cleanup(self):
self.original.cleanup() self.original.cleanup()
def read(self) -> bytes: def read(self):
ret = self.original.read() ret = self.original.read()
return audioop.mul(ret, 2, min(self._volume, 2.0)) return audioop.mul(ret, 2, min(self._volume, 2.0))
class AudioPlayer(threading.Thread): class AudioPlayer(threading.Thread):
DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 DELAY = OpusEncoder.FRAME_LENGTH / 1000.0
def __init__(self, source: AudioSource, client: VoiceClient, *, after=None): def __init__(self, source, client, *, after=None):
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.daemon: bool = True self.daemon = True
self.source: AudioSource = source self.source = source
self.client: VoiceClient = client self.client = client
self.after: Optional[Callable[[Optional[Exception]], Any]] = after self.after = after
self._end: threading.Event = threading.Event() self._end = threading.Event()
self._resumed: threading.Event = threading.Event() self._resumed = threading.Event()
self._resumed.set() # we are not paused self._resumed.set() # we are not paused
self._current_error: Optional[Exception] = None self._current_error = None
self._connected: threading.Event = client._connected self._connected = client._connected
self._lock: threading.Lock = threading.Lock() self._lock = threading.Lock()
if after is not None and not callable(after): if after is not None and not callable(after):
raise TypeError('Expected a callable for the "after" parameter.') raise TypeError('Expected a callable for the "after" parameter.')
def _do_run(self) -> None: def _do_run(self):
self.loops = 0 self.loops = 0
self._start = time.perf_counter() self._start = time.perf_counter()
@@ -668,7 +596,7 @@ class AudioPlayer(threading.Thread):
delay = max(0, self.DELAY + (next_time - time.perf_counter())) delay = max(0, self.DELAY + (next_time - time.perf_counter()))
time.sleep(delay) time.sleep(delay)
def run(self) -> None: def run(self):
try: try:
self._do_run() self._do_run()
except Exception as exc: except Exception as exc:
@@ -678,53 +606,53 @@ class AudioPlayer(threading.Thread):
self.source.cleanup() self.source.cleanup()
self._call_after() self._call_after()
def _call_after(self) -> None: def _call_after(self):
error = self._current_error error = self._current_error
if self.after is not None: if self.after is not None:
try: try:
self.after(error) self.after(error)
except Exception as exc: except Exception as exc:
_log.exception('Calling the after function failed.') log.exception('Calling the after function failed.')
exc.__context__ = error exc.__context__ = error
traceback.print_exception(type(exc), exc, exc.__traceback__) traceback.print_exception(type(exc), exc, exc.__traceback__)
elif error: elif error:
msg = f'Exception in voice thread {self.name}' msg = f'Exception in voice thread {self.name}'
_log.exception(msg, exc_info=error) log.exception(msg, exc_info=error)
print(msg, file=sys.stderr) print(msg, file=sys.stderr)
traceback.print_exception(type(error), error, error.__traceback__) traceback.print_exception(type(error), error, error.__traceback__)
def stop(self) -> None: def stop(self):
self._end.set() self._end.set()
self._resumed.set() self._resumed.set()
self._speak(False) self._speak(False)
def pause(self, *, update_speaking: bool = True) -> None: def pause(self, *, update_speaking=True):
self._resumed.clear() self._resumed.clear()
if update_speaking: if update_speaking:
self._speak(False) self._speak(False)
def resume(self, *, update_speaking: bool = True) -> None: def resume(self, *, update_speaking=True):
self.loops = 0 self.loops = 0
self._start = time.perf_counter() self._start = time.perf_counter()
self._resumed.set() self._resumed.set()
if update_speaking: if update_speaking:
self._speak(True) self._speak(True)
def is_playing(self) -> bool: def is_playing(self):
return self._resumed.is_set() and not self._end.is_set() return self._resumed.is_set() and not self._end.is_set()
def is_paused(self) -> bool: def is_paused(self):
return not self._end.is_set() and not self._resumed.is_set() return not self._end.is_set() and not self._resumed.is_set()
def _set_source(self, source: AudioSource) -> None: def _set_source(self, source):
with self._lock: with self._lock:
self.pause(update_speaking=False) self.pause(update_speaking=False)
self.source = source self.source = source
self.resume(update_speaking=False) self.resume(update_speaking=False)
def _speak(self, speaking: bool) -> None: def _speak(self, speaking):
try: try:
asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop)
except Exception as e: except Exception as e:
_log.info("Speaking call in player failed: %s", e) log.info("Speaking call in player failed: %s", e)

View File

View File

@@ -22,25 +22,6 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Set, List
if TYPE_CHECKING:
from .types.raw_models import (
MessageDeleteEvent,
BulkMessageDeleteEvent,
ReactionActionEvent,
MessageUpdateEvent,
ReactionClearEvent,
ReactionClearEmojiEvent,
IntegrationDeleteEvent
)
from .message import Message
from .partial_emoji import PartialEmoji
from .member import Member
__all__ = ( __all__ = (
'RawMessageDeleteEvent', 'RawMessageDeleteEvent',
'RawBulkMessageDeleteEvent', 'RawBulkMessageDeleteEvent',
@@ -48,16 +29,13 @@ __all__ = (
'RawReactionActionEvent', 'RawReactionActionEvent',
'RawReactionClearEvent', 'RawReactionClearEvent',
'RawReactionClearEmojiEvent', 'RawReactionClearEmojiEvent',
'RawIntegrationDeleteEvent',
) )
class _RawReprMixin: class _RawReprMixin:
def __repr__(self) -> str: def __repr__(self):
value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__) value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__)
return f'<{self.__class__.__name__} {value}>' return f'<{self.__class__.__name__} {value}>'
class RawMessageDeleteEvent(_RawReprMixin): class RawMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_message_delete` event. """Represents the event payload for a :func:`on_raw_message_delete` event.
@@ -75,15 +53,14 @@ class RawMessageDeleteEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message') __slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message')
def __init__(self, data: MessageDeleteEvent) -> None: def __init__(self, data):
self.message_id: int = int(data['id']) self.message_id = int(data['id'])
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
self.cached_message: Optional[Message] = None self.cached_message = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawBulkMessageDeleteEvent(_RawReprMixin): class RawBulkMessageDeleteEvent(_RawReprMixin):
"""Represents the event payload for a :func:`on_raw_bulk_message_delete` event. """Represents the event payload for a :func:`on_raw_bulk_message_delete` event.
@@ -102,16 +79,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin):
__slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages') __slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages')
def __init__(self, data: BulkMessageDeleteEvent) -> None: def __init__(self, data):
self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])} self.message_ids = {int(x) for x in data.get('ids', [])}
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
self.cached_messages: List[Message] = [] self.cached_messages = []
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawMessageUpdateEvent(_RawReprMixin): class RawMessageUpdateEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_message_edit` event. """Represents the payload for a :func:`on_raw_message_edit` event.
@@ -138,17 +114,16 @@ class RawMessageUpdateEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message') __slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message')
def __init__(self, data: MessageUpdateEvent) -> None: def __init__(self, data):
self.message_id: int = int(data['id']) self.message_id = int(data['id'])
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
self.data: MessageUpdateEvent = data self.data = data
self.cached_message: Optional[Message] = None self.cached_message = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawReactionActionEvent(_RawReprMixin): class RawReactionActionEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_add` or """Represents the payload for a :func:`on_raw_reaction_add` or
@@ -182,19 +157,18 @@ class RawReactionActionEvent(_RawReprMixin):
__slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji',
'event_type', 'member') 'event_type', 'member')
def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: def __init__(self, data, emoji, event_type):
self.message_id: int = int(data['message_id']) self.message_id = int(data['message_id'])
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
self.user_id: int = int(data['user_id']) self.user_id = int(data['user_id'])
self.emoji: PartialEmoji = emoji self.emoji = emoji
self.event_type: str = event_type self.event_type = event_type
self.member: Optional[Member] = None self.member = None
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawReactionClearEvent(_RawReprMixin): class RawReactionClearEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear` event. """Represents the payload for a :func:`on_raw_reaction_clear` event.
@@ -211,15 +185,14 @@ class RawReactionClearEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id') __slots__ = ('message_id', 'channel_id', 'guild_id')
def __init__(self, data: ReactionClearEvent) -> None: def __init__(self, data):
self.message_id: int = int(data['message_id']) self.message_id = int(data['message_id'])
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawReactionClearEmojiEvent(_RawReprMixin): class RawReactionClearEmojiEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_reaction_clear_emoji` event. """Represents the payload for a :func:`on_raw_reaction_clear_emoji` event.
@@ -240,39 +213,12 @@ class RawReactionClearEmojiEvent(_RawReprMixin):
__slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji') __slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji')
def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: def __init__(self, data, emoji):
self.emoji: PartialEmoji = emoji self.emoji = emoji
self.message_id: int = int(data['message_id']) self.message_id = int(data['message_id'])
self.channel_id: int = int(data['channel_id']) self.channel_id = int(data['channel_id'])
try: try:
self.guild_id: Optional[int] = int(data['guild_id']) self.guild_id = int(data['guild_id'])
except KeyError: except KeyError:
self.guild_id: Optional[int] = None self.guild_id = None
class RawIntegrationDeleteEvent(_RawReprMixin):
"""Represents the payload for a :func:`on_raw_integration_delete` event.
.. versionadded:: 2.0
Attributes
-----------
integration_id: :class:`int`
The ID of the integration that got deleted.
application_id: Optional[:class:`int`]
The ID of the bot/OAuth2 application for this deleted integration.
guild_id: :class:`int`
The guild ID where the integration got deleted.
"""
__slots__ = ('integration_id', 'application_id', 'guild_id')
def __init__(self, data: IntegrationDeleteEvent) -> None:
self.integration_id: int = int(data['id'])
self.guild_id: int = int(data['guild_id'])
try:
self.application_id: Optional[int] = int(data['application_id'])
except KeyError:
self.application_id: Optional[int] = None

View File

@@ -22,22 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Union, Optional
from .iterators import ReactionIterator from .iterators import ReactionIterator
__all__ = ( __all__ = (
'Reaction', 'Reaction',
) )
if TYPE_CHECKING:
from .types.message import Reaction as ReactionPayload
from .message import Message
from .partial_emoji import PartialEmoji
from .emoji import Emoji
from .abc import Snowflake
class Reaction: class Reaction:
"""Represents a reaction to a message. """Represents a reaction to a message.
@@ -77,35 +67,35 @@ class Reaction:
""" """
__slots__ = ('message', 'count', 'emoji', 'me') __slots__ = ('message', 'count', 'emoji', 'me')
def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): def __init__(self, *, message, data, emoji=None):
self.message: Message = message self.message = message
self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) self.emoji = emoji or message._state.get_reaction_emoji(data['emoji'])
self.count: int = data.get('count', 1) self.count = data.get('count', 1)
self.me: bool = data.get('me') self.me = data.get('me')
# TODO: typeguard @property
def is_custom_emoji(self) -> bool: def custom_emoji(self):
""":class:`bool`: If this is a custom emoji.""" """:class:`bool`: If this is a custom emoji."""
return not isinstance(self.emoji, str) return not isinstance(self.emoji, str)
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, self.__class__) and other.emoji == self.emoji return isinstance(other, self.__class__) and other.emoji == self.emoji
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
if isinstance(other, self.__class__): if isinstance(other, self.__class__):
return other.emoji != self.emoji return other.emoji != self.emoji
return True return True
def __hash__(self) -> int: def __hash__(self):
return hash(self.emoji) return hash(self.emoji)
def __str__(self) -> str: def __str__(self):
return str(self.emoji) return str(self.emoji)
def __repr__(self) -> str: def __repr__(self):
return f'<Reaction emoji={self.emoji!r} me={self.me} count={self.count}>' return '<Reaction emoji={0.emoji!r} me={0.me} count={0.count}>'.format(self)
async def remove(self, user: Snowflake) -> None: async def remove(self, user):
"""|coro| """|coro|
Remove the reaction by the provided :class:`User` from the message. Remove the reaction by the provided :class:`User` from the message.
@@ -133,7 +123,7 @@ class Reaction:
await self.message.remove_reaction(self.emoji, user) await self.message.remove_reaction(self.emoji, user)
async def clear(self) -> None: async def clear(self):
"""|coro| """|coro|
Clears this reaction from the message. Clears this reaction from the message.
@@ -155,7 +145,7 @@ class Reaction:
""" """
await self.message.clear_reaction(self.emoji) await self.message.clear_reaction(self.emoji)
def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator: def users(self, limit=None, after=None):
"""Returns an :class:`AsyncIterator` representing the users that have reacted to the message. """Returns an :class:`AsyncIterator` representing the users that have reacted to the message.
The ``after`` parameter must represent a member The ``after`` parameter must represent a member
@@ -168,22 +158,22 @@ class Reaction:
# I do not actually recommend doing this. # I do not actually recommend doing this.
async for user in reaction.users(): async for user in reaction.users():
await channel.send(f'{user} has reacted with {reaction.emoji}!') await channel.send('{0} has reacted with {1.emoji}!'.format(user, reaction))
Flattening into a list: :: Flattening into a list: ::
users = await reaction.users().flatten() users = await reaction.users().flatten()
# users is now a list of User... # users is now a list of User...
winner = random.choice(users) winner = random.choice(users)
await channel.send(f'{winner} has won the raffle.') await channel.send('{} has won the raffle.'.format(winner))
Parameters Parameters
------------ ------------
limit: Optional[:class:`int`] limit: :class:`int`
The maximum number of results to return. The maximum number of results to return.
If not provided, returns all the users who If not provided, returns all the users who
reacted to the message. reacted to the message.
after: Optional[:class:`abc.Snowflake`] after: :class:`abc.Snowflake`
For pagination, reactions are sorted by member. For pagination, reactions are sorted by member.
Raises Raises
@@ -200,8 +190,8 @@ class Reaction:
if the member has left the guild. if the member has left the guild.
""" """
if not isinstance(self.emoji, str): if self.custom_emoji:
emoji = f'{self.emoji.name}:{self.emoji.id}' emoji = '{0.name}:{0.id}'.format(self.emoji)
else: else:
emoji = self.emoji emoji = self.emoji

View File

@@ -22,32 +22,17 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, Dict, List, Optional, TypeVar, Union, overload, TYPE_CHECKING
from .permissions import Permissions from .permissions import Permissions
from .errors import InvalidArgument from .errors import InvalidArgument
from .colour import Colour from .colour import Colour
from .mixins import Hashable from .mixins import Hashable
from .utils import snowflake_time, _get_as_snowflake, MISSING from .utils import snowflake_time, _get_as_snowflake
__all__ = ( __all__ = (
'RoleTags', 'RoleTags',
'Role', 'Role',
) )
if TYPE_CHECKING:
import datetime
from .types.role import (
Role as RolePayload,
RoleTags as RoleTagPayload,
)
from .types.guild import RolePositionUpdate
from .guild import Guild
from .member import Member
from .state import ConnectionState
class RoleTags: class RoleTags:
"""Represents tags on a role. """Represents tags on a role.
@@ -67,42 +52,32 @@ class RoleTags:
The integration ID that manages the role. The integration ID that manages the role.
""" """
__slots__ = ( __slots__ = ('bot_id', 'integration_id', '_premium_subscriber',)
'bot_id',
'integration_id',
'_premium_subscriber',
)
def __init__(self, data: RoleTagPayload): def __init__(self, data):
self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id') self.bot_id = _get_as_snowflake(data, 'bot_id')
self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id') self.integration_id = _get_as_snowflake(data, 'integration_id')
# NOTE: The API returns "null" for this if it's valid, which corresponds to None. # NOTE: The API returns "null" for this if it's valid, which corresponds to None.
# This is different from other fields where "null" means "not there". # This is different from other fields where "null" means "not there".
# So in this case, a value of None is the same as True. # So in this case, a value of None is the same as True.
# Which means we would need a different sentinel. # Which means we would need a different sentinel. For this purpose I used ellipsis.
self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING) self._premium_subscriber = data.get('premium_subscriber', ...)
def is_bot_managed(self) -> bool: def is_bot_managed(self):
""":class:`bool`: Whether the role is associated with a bot.""" """:class:`bool`: Whether the role is associated with a bot."""
return self.bot_id is not None return self.bot_id is not None
def is_premium_subscriber(self) -> bool: def is_premium_subscriber(self):
""":class:`bool`: Whether the role is the premium subscriber, AKA "boost", role for the guild.""" """:class:`bool`: Whether the role is the premium subscriber, AKA "boost", role for the guild."""
return self._premium_subscriber is None return self._premium_subscriber is None
def is_integration(self) -> bool: def is_integration(self):
""":class:`bool`: Whether the role is managed by an integration.""" """:class:`bool`: Whether the role is managed by an integration."""
return self.integration_id is not None return self.integration_id is not None
def __repr__(self) -> str: def __repr__(self):
return ( return '<RoleTags bot_id={0.bot_id} integration_id={0.integration_id} ' \
f'<RoleTags bot_id={self.bot_id} integration_id={self.integration_id} ' 'premium_subscriber={1}>'.format(self, self.is_premium_subscriber())
f'premium_subscriber={self.is_premium_subscriber()}>'
)
R = TypeVar('R', bound='Role')
class Role(Hashable): class Role(Hashable):
"""Represents a Discord role in a :class:`Guild`. """Represents a Discord role in a :class:`Guild`.
@@ -154,15 +129,6 @@ class Role(Hashable):
position: :class:`int` position: :class:`int`
The position of the role. This number is usually positive. The bottom The position of the role. This number is usually positive. The bottom
role has a position of 0. role has a position of 0.
.. warning::
Multiple roles can have the same position number. As a consequence
of this, comparing via role position is prone to subtle bugs if
checking for role hierarchy. The recommended and correct way to
compare for roles in the hierarchy is using the comparison
operators on the role objects themselves.
managed: :class:`bool` managed: :class:`bool`
Indicates if the role is managed by the guild through some form of Indicates if the role is managed by the guild through some form of
integrations such as Twitch. integrations such as Twitch.
@@ -172,33 +138,25 @@ class Role(Hashable):
The role tags associated with this role. The role tags associated with this role.
""" """
__slots__ = ( __slots__ = ('id', 'name', '_permissions', '_colour', 'position',
'id', 'managed', 'mentionable', 'hoist', 'guild', 'tags', '_state')
'name',
'_permissions',
'_colour',
'position',
'managed',
'mentionable',
'hoist',
'guild',
'tags',
'_state',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload): def __init__(self, *, guild, state, data):
self.guild: Guild = guild self.guild = guild
self._state: ConnectionState = state self._state = state
self.id: int = int(data['id']) self.id = int(data['id'])
self._update(data) self._update(data)
def __str__(self) -> str: def __str__(self):
return self.name return self.name
def __repr__(self) -> str: def __int__(self):
return f'<Role id={self.id} name={self.name!r}>' return self.id
def __lt__(self: R, other: R) -> bool: def __repr__(self):
return '<Role id={0.id} name={0.name!r}>'.format(self)
def __lt__(self, other):
if not isinstance(other, Role) or not isinstance(self, Role): if not isinstance(other, Role) or not isinstance(self, Role):
return NotImplemented return NotImplemented
@@ -219,96 +177,87 @@ class Role(Hashable):
return False return False
def __le__(self: R, other: R) -> bool: def __le__(self, other):
r = Role.__lt__(other, self) r = Role.__lt__(other, self)
if r is NotImplemented: if r is NotImplemented:
return NotImplemented return NotImplemented
return not r return not r
def __gt__(self: R, other: R) -> bool: def __gt__(self, other):
return Role.__lt__(other, self) return Role.__lt__(other, self)
def __ge__(self: R, other: R) -> bool: def __ge__(self, other):
r = Role.__lt__(self, other) r = Role.__lt__(self, other)
if r is NotImplemented: if r is NotImplemented:
return NotImplemented return NotImplemented
return not r return not r
def _update(self, data: RolePayload): def _update(self, data):
self.name: str = data['name'] self.name = data['name']
self._permissions: int = int(data.get('permissions', 0)) self._permissions = int(data.get('permissions_new', 0))
self.position: int = data.get('position', 0) self.position = data.get('position', 0)
self._colour: int = data.get('color', 0) self._colour = data.get('color', 0)
self.hoist: bool = data.get('hoist', False) self.hoist = data.get('hoist', False)
self.managed: bool = data.get('managed', False) self.managed = data.get('managed', False)
self.mentionable: bool = data.get('mentionable', False) self.mentionable = data.get('mentionable', False)
self.tags: Optional[RoleTags]
try: try:
self.tags = RoleTags(data['tags']) self.tags = RoleTags(data['tags'])
except KeyError: except KeyError:
self.tags = None self.tags = None
def is_default(self) -> bool: def is_default(self):
""":class:`bool`: Checks if the role is the default role.""" """:class:`bool`: Checks if the role is the default role."""
return self.guild.id == self.id return self.guild.id == self.id
def is_bot_managed(self) -> bool: def is_bot_managed(self):
""":class:`bool`: Whether the role is associated with a bot. """:class:`bool`: Whether the role is associated with a bot.
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
return self.tags is not None and self.tags.is_bot_managed() return self.tags is not None and self.tags.is_bot_managed()
def is_premium_subscriber(self) -> bool: def is_premium_subscriber(self):
""":class:`bool`: Whether the role is the premium subscriber, AKA "boost", role for the guild. """:class:`bool`: Whether the role is the premium subscriber, AKA "boost", role for the guild.
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
return self.tags is not None and self.tags.is_premium_subscriber() return self.tags is not None and self.tags.is_premium_subscriber()
def is_integration(self) -> bool: def is_integration(self):
""":class:`bool`: Whether the role is managed by an integration. """:class:`bool`: Whether the role is managed by an integration.
.. versionadded:: 1.6 .. versionadded:: 1.6
""" """
return self.tags is not None and self.tags.is_integration() return self.tags is not None and self.tags.is_integration()
def is_assignable(self) -> bool:
""":class:`bool`: Whether the role is able to be assigned or removed by the bot.
.. versionadded:: 2.0
"""
me = self.guild.me
return not self.is_default() and not self.managed and (me.top_role > self or me.id == self.guild.owner_id)
@property @property
def permissions(self) -> Permissions: def permissions(self):
""":class:`Permissions`: Returns the role's permissions.""" """:class:`Permissions`: Returns the role's permissions."""
return Permissions(self._permissions) return Permissions(self._permissions)
@property @property
def colour(self) -> Colour: def colour(self):
""":class:`Colour`: Returns the role colour. An alias exists under ``color``.""" """:class:`Colour`: Returns the role colour. An alias exists under ``color``."""
return Colour(self._colour) return Colour(self._colour)
@property @property
def color(self) -> Colour: def color(self):
""":class:`Colour`: Returns the role color. An alias exists under ``colour``.""" """:class:`Colour`: Returns the role color. An alias exists under ``colour``."""
return self.colour return self.colour
@property @property
def created_at(self) -> datetime.datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the role's creation time in UTC.""" """:class:`datetime.datetime`: Returns the role's creation time in UTC."""
return snowflake_time(self.id) return snowflake_time(self.id)
@property @property
def mention(self) -> str: def mention(self):
""":class:`str`: Returns a string that allows you to mention a role.""" """:class:`str`: Returns a string that allows you to mention a role."""
return f'<@&{self.id}>' return f'<@&{self.id}>'
@property @property
def members(self) -> List[Member]: def members(self):
"""List[:class:`Member`]: Returns all the members with this role.""" """List[:class:`Member`]: Returns all the members with this role."""
all_members = self.guild.members all_members = self.guild.members
if self.is_default(): if self.is_default():
@@ -317,7 +266,7 @@ class Role(Hashable):
role_id = self.id role_id = self.id
return [member for member in all_members if member._roles.has(role_id)] return [member for member in all_members if member._roles.has(role_id)]
async def _move(self, position: int, reason: Optional[str]) -> None: async def _move(self, position, reason):
if position <= 0: if position <= 0:
raise InvalidArgument("Cannot move role to position 0 or below") raise InvalidArgument("Cannot move role to position 0 or below")
@@ -337,21 +286,10 @@ class Role(Hashable):
else: else:
roles.append(self.id) roles.append(self.id)
payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)]
await http.move_role_position(self.guild.id, payload, reason=reason) await http.move_role_position(self.guild.id, payload, reason=reason)
async def edit( async def edit(self, *, reason=None, **fields):
self,
*,
name: str = MISSING,
permissions: Permissions = MISSING,
colour: Union[Colour, int] = MISSING,
color: Union[Colour, int] = MISSING,
hoist: bool = MISSING,
mentionable: bool = MISSING,
position: int = MISSING,
reason: Optional[str] = MISSING,
) -> Optional[Role]:
"""|coro| """|coro|
Edits the role. Edits the role.
@@ -364,9 +302,6 @@ class Role(Hashable):
.. versionchanged:: 1.4 .. versionchanged:: 1.4
Can now pass ``int`` to ``colour`` keyword-only parameter. Can now pass ``int`` to ``colour`` keyword-only parameter.
.. versionchanged:: 2.0
Edits are no longer in-place, the newly edited role is returned instead.
Parameters Parameters
----------- -----------
name: :class:`str` name: :class:`str`
@@ -394,41 +329,33 @@ class Role(Hashable):
InvalidArgument InvalidArgument
An invalid position was given or the default An invalid position was given or the default
role was asked to be moved. role was asked to be moved.
Returns
--------
:class:`Role`
The newly edited role.
""" """
if position is not MISSING:
position = fields.get('position')
if position is not None:
await self._move(position, reason=reason) await self._move(position, reason=reason)
self.position = position
payload: Dict[str, Any] = {} try:
if color is not MISSING: colour = fields['colour']
colour = color except KeyError:
colour = fields.get('color', self.colour)
if colour is not MISSING: if isinstance(colour, int):
if isinstance(colour, int): colour = Colour(value=colour)
payload['color'] = colour
else:
payload['color'] = colour.value
if name is not MISSING: payload = {
payload['name'] = name 'name': fields.get('name', self.name),
'permissions': str(fields.get('permissions', self.permissions).value),
if permissions is not MISSING: 'color': colour.value,
payload['permissions'] = permissions.value 'hoist': fields.get('hoist', self.hoist),
'mentionable': fields.get('mentionable', self.mentionable)
if hoist is not MISSING: }
payload['hoist'] = hoist
if mentionable is not MISSING:
payload['mentionable'] = mentionable
data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload)
return Role(guild=self.guild, data=data, state=self._state) self._update(data)
async def delete(self, *, reason: Optional[str] = None) -> None: async def delete(self, *, reason=None):
"""|coro| """|coro|
Deletes the role. Deletes the role.

View File

@@ -22,9 +22,8 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import asyncio import asyncio
import itertools
import logging import logging
import aiohttp import aiohttp
@@ -35,30 +34,22 @@ from .backoff import ExponentialBackoff
from .gateway import * from .gateway import *
from .errors import ( from .errors import (
ClientException, ClientException,
InvalidArgument,
HTTPException, HTTPException,
GatewayNotFound, GatewayNotFound,
ConnectionClosed, ConnectionClosed,
PrivilegedIntentsRequired, PrivilegedIntentsRequired,
) )
from . import utils
from .enums import Status from .enums import Status
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar
if TYPE_CHECKING:
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .enums import Status
EI = TypeVar('EI', bound='EventItem')
__all__ = ( __all__ = (
'AutoShardedClient', 'AutoShardedClient',
'ShardInfo', 'ShardInfo',
) )
_log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class EventType: class EventType:
close = 0 close = 0
@@ -68,41 +59,39 @@ class EventType:
terminate = 4 terminate = 4
clean_close = 5 clean_close = 5
class EventItem: class EventItem:
__slots__ = ('type', 'shard', 'error') __slots__ = ('type', 'shard', 'error')
def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: def __init__(self, etype, shard, error):
self.type: int = etype self.type = etype
self.shard: Optional['Shard'] = shard self.shard = shard
self.error: Optional[Exception] = error self.error = error
def __lt__(self: EI, other: EI) -> bool: def __lt__(self, other):
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type < other.type return self.type < other.type
def __eq__(self: EI, other: EI) -> bool: def __eq__(self, other):
if not isinstance(other, EventItem): if not isinstance(other, EventItem):
return NotImplemented return NotImplemented
return self.type == other.type return self.type == other.type
def __hash__(self) -> int: def __hash__(self):
return hash(self.type) return hash(self.type)
class Shard: class Shard:
def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None: def __init__(self, ws, client, queue_put):
self.ws: DiscordWebSocket = ws self.ws = ws
self._client: Client = client self._client = client
self._dispatch: Callable[..., None] = client.dispatch self._dispatch = client.dispatch
self._queue_put: Callable[[EventItem], None] = queue_put self._queue_put = queue_put
self.loop: asyncio.AbstractEventLoop = self._client.loop self.loop = self._client.loop
self._disconnect: bool = False self._disconnect = False
self._reconnect = client._reconnect self._reconnect = client._reconnect
self._backoff: ExponentialBackoff = ExponentialBackoff() self._backoff = ExponentialBackoff()
self._task: Optional[asyncio.Task] = None self._task = None
self._handled_exceptions: Tuple[Type[Exception], ...] = ( self._handled_exceptions = (
OSError, OSError,
HTTPException, HTTPException,
GatewayNotFound, GatewayNotFound,
@@ -112,26 +101,25 @@ class Shard:
) )
@property @property
def id(self) -> int: def id(self):
# DiscordWebSocket.shard_id is set in the from_client classmethod return self.ws.shard_id
return self.ws.shard_id # type: ignore
def launch(self) -> None: def launch(self):
self._task = self.loop.create_task(self.worker()) self._task = self.loop.create_task(self.worker())
def _cancel_task(self) -> None: def _cancel_task(self):
if self._task is not None and not self._task.done(): if self._task is not None and not self._task.done():
self._task.cancel() self._task.cancel()
async def close(self) -> None: async def close(self):
self._cancel_task() self._cancel_task()
await self.ws.close(code=1000) await self.ws.close(code=1000)
async def disconnect(self) -> None: async def disconnect(self):
await self.close() await self.close()
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
async def _handle_disconnect(self, e: Exception) -> None: async def _handle_disconnect(self, e):
self._dispatch('disconnect') self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
if not self._reconnect: if not self._reconnect:
@@ -156,11 +144,11 @@ class Shard:
return return
retry = self._backoff.delay() retry = self._backoff.delay()
_log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e)
await asyncio.sleep(retry) await asyncio.sleep(retry)
self._queue_put(EventItem(EventType.reconnect, self, e)) self._queue_put(EventItem(EventType.reconnect, self, e))
async def worker(self) -> None: async def worker(self):
while not self._client.is_closed(): while not self._client.is_closed():
try: try:
await self.ws.poll_event() await self.ws.poll_event()
@@ -177,19 +165,14 @@ class Shard:
self._queue_put(EventItem(EventType.terminate, self, e)) self._queue_put(EventItem(EventType.terminate, self, e))
break break
async def reidentify(self, exc: ReconnectWebSocket) -> None: async def reidentify(self, exc):
self._cancel_task() self._cancel_task()
self._dispatch('disconnect') self._dispatch('disconnect')
self._dispatch('shard_disconnect', self.id) self._dispatch('shard_disconnect', self.id)
_log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
try: try:
coro = DiscordWebSocket.from_client( coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
self._client, session=self.ws.session_id, sequence=self.ws.sequence)
resume=exc.resume,
shard_id=self.id,
session=self.ws.session_id,
sequence=self.ws.sequence,
)
self.ws = await asyncio.wait_for(coro, timeout=60.0) self.ws = await asyncio.wait_for(coro, timeout=60.0)
except self._handled_exceptions as e: except self._handled_exceptions as e:
await self._handle_disconnect(e) await self._handle_disconnect(e)
@@ -200,7 +183,7 @@ class Shard:
else: else:
self.launch() self.launch()
async def reconnect(self) -> None: async def reconnect(self):
self._cancel_task() self._cancel_task()
try: try:
coro = DiscordWebSocket.from_client(self._client, shard_id=self.id) coro = DiscordWebSocket.from_client(self._client, shard_id=self.id)
@@ -214,7 +197,6 @@ class Shard:
else: else:
self.launch() self.launch()
class ShardInfo: class ShardInfo:
"""A class that gives information and control over a specific shard. """A class that gives information and control over a specific shard.
@@ -233,16 +215,16 @@ class ShardInfo:
__slots__ = ('_parent', 'id', 'shard_count') __slots__ = ('_parent', 'id', 'shard_count')
def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: def __init__(self, parent, shard_count):
self._parent: Shard = parent self._parent = parent
self.id: int = parent.id self.id = parent.id
self.shard_count: Optional[int] = shard_count self.shard_count = shard_count
def is_closed(self) -> bool: def is_closed(self):
""":class:`bool`: Whether the shard connection is currently closed.""" """:class:`bool`: Whether the shard connection is currently closed."""
return not self._parent.ws.open return not self._parent.ws.open
async def disconnect(self) -> None: async def disconnect(self):
"""|coro| """|coro|
Disconnects a shard. When this is called, the shard connection will no Disconnects a shard. When this is called, the shard connection will no
@@ -255,7 +237,7 @@ class ShardInfo:
await self._parent.disconnect() await self._parent.disconnect()
async def reconnect(self) -> None: async def reconnect(self):
"""|coro| """|coro|
Disconnects and then connects the shard again. Disconnects and then connects the shard again.
@@ -264,7 +246,7 @@ class ShardInfo:
await self._parent.disconnect() await self._parent.disconnect()
await self._parent.reconnect() await self._parent.reconnect()
async def connect(self) -> None: async def connect(self):
"""|coro| """|coro|
Connects a shard. If the shard is already connected this does nothing. Connects a shard. If the shard is already connected this does nothing.
@@ -275,11 +257,11 @@ class ShardInfo:
await self._parent.reconnect() await self._parent.reconnect()
@property @property
def latency(self) -> float: def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard.""" """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds for this shard."""
return self._parent.ws.latency return self._parent.ws.latency
def is_ws_ratelimited(self) -> bool: def is_ws_ratelimited(self):
""":class:`bool`: Whether the websocket is currently rate limited. """:class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members This can be useful to know when deciding whether you should query members
@@ -289,7 +271,6 @@ class ShardInfo:
""" """
return self._parent.ws.is_ratelimited() return self._parent.ws.is_ratelimited()
class AutoShardedClient(Client): class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications """A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single of sharding for the user into a more manageable and transparent single
@@ -316,13 +297,9 @@ class AutoShardedClient(Client):
shard_ids: Optional[List[:class:`int`]] shard_ids: Optional[List[:class:`int`]]
An optional list of shard_ids to launch the shards with. An optional list of shard_ids to launch the shards with.
""" """
def __init__(self, *args, loop=None, **kwargs):
if TYPE_CHECKING:
_connection: AutoShardedConnectionState
def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None:
kwargs.pop('shard_id', None) kwargs.pop('shard_id', None)
self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) self.shard_ids = kwargs.pop('shard_ids', None)
super().__init__(*args, loop=loop, **kwargs) super().__init__(*args, loop=loop, **kwargs)
if self.shard_ids is not None: if self.shard_ids is not None:
@@ -338,24 +315,18 @@ class AutoShardedClient(Client):
self._connection._get_client = lambda: self self._connection._get_client = lambda: self
self.__queue = asyncio.PriorityQueue() self.__queue = asyncio.PriorityQueue()
def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: def _get_websocket(self, guild_id=None, *, shard_id=None):
if shard_id is None: if shard_id is None:
# guild_id won't be None if shard_id is None and shard_count won't be None here shard_id = (guild_id >> 22) % self.shard_count
shard_id = (guild_id >> 22) % self.shard_count # type: ignore
return self.__shards[shard_id].ws return self.__shards[shard_id].ws
def _get_state(self, **options: Any) -> AutoShardedConnectionState: def _get_state(self, **options):
return AutoShardedConnectionState( return AutoShardedConnectionState(dispatch=self.dispatch,
dispatch=self.dispatch, handlers=self._handlers,
handlers=self._handlers, hooks=self._hooks, http=self.http, loop=self.loop, **options)
hooks=self._hooks,
http=self.http,
loop=self.loop,
**options,
)
@property @property
def latency(self) -> float: def latency(self):
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This operates similarly to :meth:`Client.latency` except it uses the average This operates similarly to :meth:`Client.latency` except it uses the average
@@ -367,14 +338,14 @@ class AutoShardedClient(Client):
return sum(latency for _, latency in self.latencies) / len(self.__shards) return sum(latency for _, latency in self.latencies) / len(self.__shards)
@property @property
def latencies(self) -> List[Tuple[int, float]]: def latencies(self):
"""List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. """List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This returns a list of tuples with elements ``(shard_id, latency)``. This returns a list of tuples with elements ``(shard_id, latency)``.
""" """
return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()]
def get_shard(self, shard_id: int) -> Optional[ShardInfo]: def get_shard(self, shard_id):
"""Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" """Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found."""
try: try:
parent = self.__shards[shard_id] parent = self.__shards[shard_id]
@@ -384,16 +355,52 @@ class AutoShardedClient(Client):
return ShardInfo(parent, self.shard_count) return ShardInfo(parent, self.shard_count)
@property @property
def shards(self) -> Dict[int, ShardInfo]: def shards(self):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: @utils.deprecated('Guild.chunk')
async def request_offline_members(self, *guilds):
r"""|coro|
Requests previously offline members from the guild to be filled up
into the :attr:`Guild.members` cache. This function is usually not
called. It should only be used if you have the ``fetch_offline_members``
parameter set to ``False``.
When the client logs on and connects to the websocket, Discord does
not provide the library with offline members if the number of members
in the guild is larger than 250. You can check if a guild is large
if :attr:`Guild.large` is ``True``.
.. warning::
This method is deprecated. Use :meth:`Guild.chunk` instead.
Parameters
-----------
\*guilds: :class:`Guild`
An argument list of guilds to request offline members for.
Raises
-------
InvalidArgument
If any guild is unavailable in the collection.
"""
if any(g.unavailable for g in guilds):
raise InvalidArgument('An unavailable or non-large guild was passed.')
_guilds = sorted(guilds, key=lambda g: g.shard_id)
for shard_id, sub_guilds in itertools.groupby(_guilds, key=lambda g: g.shard_id):
for guild in sub_guilds:
await self._connection.chunk_guild(guild)
async def launch_shard(self, gateway, shard_id, *, initial=False):
try: try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
ws = await asyncio.wait_for(coro, timeout=180.0) ws = await asyncio.wait_for(coro, timeout=180.0)
except Exception: except Exception:
_log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
await asyncio.sleep(5.0) await asyncio.sleep(5.0)
return await self.launch_shard(gateway, shard_id) return await self.launch_shard(gateway, shard_id)
@@ -401,7 +408,7 @@ class AutoShardedClient(Client):
self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait) self.__shards[shard_id] = ret = Shard(ws, self, self.__queue.put_nowait)
ret.launch() ret.launch()
async def launch_shards(self) -> None: async def launch_shards(self):
if self.shard_count is None: if self.shard_count is None:
self.shard_count, gateway = await self.http.get_bot_gateway() self.shard_count, gateway = await self.http.get_bot_gateway()
else: else:
@@ -418,7 +425,7 @@ class AutoShardedClient(Client):
self._connection.shards_launched.set() self._connection.shards_launched.set()
async def connect(self, *, reconnect: bool = True) -> None: async def connect(self, *, reconnect=True):
self._reconnect = reconnect self._reconnect = reconnect
await self.launch_shards() await self.launch_shards()
@@ -442,7 +449,7 @@ class AutoShardedClient(Client):
elif item.type == EventType.clean_close: elif item.type == EventType.clean_close:
return return
async def close(self) -> None: async def close(self):
"""|coro| """|coro|
Closes the connection to Discord. Closes the connection to Discord.
@@ -454,7 +461,7 @@ class AutoShardedClient(Client):
for vc in self.voice_clients: for vc in self.voice_clients:
try: try:
await vc.disconnect(force=True) await vc.disconnect()
except Exception: except Exception:
pass pass
@@ -465,13 +472,7 @@ class AutoShardedClient(Client):
await self.http.close() await self.http.close()
self.__queue.put_nowait(EventItem(EventType.clean_close, None, None)) self.__queue.put_nowait(EventItem(EventType.clean_close, None, None))
async def change_presence( async def change_presence(self, *, activity=None, status=None, afk=False, shard_id=None):
self,
*,
activity: Optional[BaseActivity] = None,
status: Optional[Status] = None,
shard_id: int = None,
) -> None:
"""|coro| """|coro|
Changes the client's presence. Changes the client's presence.
@@ -481,9 +482,6 @@ class AutoShardedClient(Client):
game = discord.Game("with the API") game = discord.Game("with the API")
await client.change_presence(status=discord.Status.idle, activity=game) await client.change_presence(status=discord.Status.idle, activity=game)
.. versionchanged:: 2.0
Removed the ``afk`` keyword-only parameter.
Parameters Parameters
---------- ----------
activity: Optional[:class:`BaseActivity`] activity: Optional[:class:`BaseActivity`]
@@ -491,6 +489,10 @@ class AutoShardedClient(Client):
status: Optional[:class:`Status`] status: Optional[:class:`Status`]
Indicates what status to change to. If ``None``, then Indicates what status to change to. If ``None``, then
:attr:`Status.online` is used. :attr:`Status.online` is used.
afk: :class:`bool`
Indicates if you are going AFK. This allows the discord
client to know how to handle push notifications better
for you in case you are actually idle and not lying.
shard_id: Optional[:class:`int`] shard_id: Optional[:class:`int`]
The shard_id to change the presence to. If not specified The shard_id to change the presence to. If not specified
or ``None``, then it will change the presence of every or ``None``, then it will change the presence of every
@@ -503,23 +505,23 @@ class AutoShardedClient(Client):
""" """
if status is None: if status is None:
status_value = 'online' status = 'online'
status_enum = Status.online status_enum = Status.online
elif status is Status.offline: elif status is Status.offline:
status_value = 'invisible' status = 'invisible'
status_enum = Status.offline status_enum = Status.offline
else: else:
status_enum = status status_enum = status
status_value = str(status) status = str(status)
if shard_id is None: if shard_id is None:
for shard in self.__shards.values(): for shard in self.__shards.values():
await shard.ws.change_presence(activity=activity, status=status_value) await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = self._connection.guilds guilds = self._connection.guilds
else: else:
shard = self.__shards[shard_id] shard = self.__shards[shard_id]
await shard.ws.change_presence(activity=activity, status=status_value) await shard.ws.change_presence(activity=activity, status=status, afk=afk)
guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] guilds = [g for g in self._connection.guilds if g.shard_id == shard_id]
activities = () if activity is None else (activity,) activities = () if activity is None else (activity,)
@@ -528,11 +530,10 @@ class AutoShardedClient(Client):
if me is None: if me is None:
continue continue
# Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] me.activities = activities
me.activities = activities # type: ignore
me.status = status_enum me.status = status_enum
def is_ws_ratelimited(self) -> bool: def is_ws_ratelimited(self):
""":class:`bool`: Whether the websocket is currently rate limited. """:class:`bool`: Whether the websocket is currently rate limited.
This can be useful to know when deciding whether you should query members This can be useful to know when deciding whether you should query members

View File

@@ -1,176 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from .utils import MISSING, cached_slot_property
from .mixins import Hashable
from .errors import InvalidArgument
from .enums import StagePrivacyLevel, try_enum
__all__ = (
'StageInstance',
)
if TYPE_CHECKING:
from .types.channel import StageInstance as StageInstancePayload
from .state import ConnectionState
from .channel import StageChannel
from .guild import Guild
class StageInstance(Hashable):
"""Represents a stage instance of a stage channel in a guild.
.. versionadded:: 2.0
.. container:: operations
.. describe:: x == y
Checks if two stage instances are equal.
.. describe:: x != y
Checks if two stage instances are not equal.
.. describe:: hash(x)
Returns the stage instance's hash.
Attributes
-----------
id: :class:`int`
The stage instance's ID.
guild: :class:`Guild`
The guild that the stage instance is running in.
channel_id: :class:`int`
The ID of the channel that the stage instance is running in.
topic: :class:`str`
The topic of the stage instance.
privacy_level: :class:`StagePrivacyLevel`
The privacy level of the stage instance.
discoverable_disabled: :class:`bool`
Whether discoverability for the stage instance is disabled.
"""
__slots__ = (
'_state',
'id',
'guild',
'channel_id',
'topic',
'privacy_level',
'discoverable_disabled',
'_cs_channel',
)
def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None:
self._state = state
self.guild = guild
self._update(data)
def _update(self, data: StageInstancePayload):
self.id: int = int(data['id'])
self.channel_id: int = int(data['channel_id'])
self.topic: str = data['topic']
self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level'])
self.discoverable_disabled: bool = data.get('discoverable_disabled', False)
def __repr__(self) -> str:
return f'<StageInstance id={self.id} guild={self.guild!r} channel_id={self.channel_id} topic={self.topic!r}>'
@cached_slot_property('_cs_channel')
def channel(self) -> Optional[StageChannel]:
"""Optional[:class:`StageChannel`]: The channel that stage instance is running in."""
# the returned channel will always be a StageChannel or None
return self._state.get_channel(self.channel_id) # type: ignore
def is_public(self) -> bool:
return self.privacy_level is StagePrivacyLevel.public
async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None:
"""|coro|
Edits the stage instance.
You must have the :attr:`~Permissions.manage_channels` permission to
use this.
Parameters
-----------
topic: :class:`str`
The stage instance's new topic.
privacy_level: :class:`StagePrivacyLevel`
The stage instance's new privacy level.
reason: :class:`str`
The reason the stage instance was edited. Shows up on the audit log.
Raises
------
InvalidArgument
If the ``privacy_level`` parameter is not the proper type.
Forbidden
You do not have permissions to edit the stage instance.
HTTPException
Editing a stage instance failed.
"""
payload = {}
if topic is not MISSING:
payload['topic'] = topic
if privacy_level is not MISSING:
if not isinstance(privacy_level, StagePrivacyLevel):
raise InvalidArgument('privacy_level field must be of type PrivacyLevel')
payload['privacy_level'] = privacy_level.value
if payload:
await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the stage instance.
You must have the :attr:`~Permissions.manage_channels` permission to
use this.
Parameters
-----------
reason: :class:`str`
The reason the stage instance was deleted. Shows up on the audit log.
Raises
------
Forbidden
You do not have permissions to delete the stage instance.
HTTPException
Deleting the stage instance failed.
"""
await self._state.http.delete_stage_instance(self.channel_id, reason=reason)

File diff suppressed because it is too large Load Diff

View File

@@ -22,217 +22,16 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Literal, TYPE_CHECKING, List, Optional, Tuple, Type, Union
import unicodedata
from .mixins import Hashable from .mixins import Hashable
from .asset import Asset, AssetMixin from .asset import Asset
from .utils import cached_slot_property, find, snowflake_time, get, MISSING from .utils import snowflake_time
from .errors import InvalidData from .enums import StickerType, try_enum
from .enums import StickerType, StickerFormatType, try_enum
__all__ = ( __all__ = (
'StickerPack',
'StickerItem',
'Sticker', 'Sticker',
'StandardSticker',
'GuildSticker',
) )
if TYPE_CHECKING: class Sticker(Hashable):
import datetime
from .state import ConnectionState
from .user import User
from .guild import Guild
from .types.sticker import (
StickerPack as StickerPackPayload,
StickerItem as StickerItemPayload,
Sticker as StickerPayload,
StandardSticker as StandardStickerPayload,
GuildSticker as GuildStickerPayload,
ListPremiumStickerPacks as ListPremiumStickerPacksPayload,
EditGuildSticker,
)
class StickerPack(Hashable):
"""Represents a sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker pack.
.. describe:: x == y
Checks if the sticker pack is equal to another sticker pack.
.. describe:: x != y
Checks if the sticker pack is not equal to another sticker pack.
Attributes
-----------
name: :class:`str`
The name of the sticker pack.
description: :class:`str`
The description of the sticker pack.
id: :class:`int`
The id of the sticker pack.
stickers: List[:class:`StandardSticker`]
The stickers of this sticker pack.
sku_id: :class:`int`
The SKU ID of the sticker pack.
cover_sticker_id: :class:`int`
The ID of the sticker used for the cover of the sticker pack.
cover_sticker: :class:`StandardSticker`
The sticker used for the cover of the sticker pack.
"""
__slots__ = (
'_state',
'id',
'stickers',
'name',
'sku_id',
'cover_sticker_id',
'cover_sticker',
'description',
'_banner',
)
def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPackPayload) -> None:
self.id: int = int(data['id'])
stickers = data['stickers']
self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers]
self.name: str = data['name']
self.sku_id: int = int(data['sku_id'])
self.cover_sticker_id: int = int(data['cover_sticker_id'])
self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore
self.description: str = data['description']
self._banner: int = int(data['banner_asset_id'])
@property
def banner(self) -> Asset:
""":class:`Asset`: The banner asset of the sticker pack."""
return Asset._from_sticker_banner(self._state, self._banner)
def __repr__(self) -> str:
return f'<StickerPack id={self.id} name={self.name!r} description={self.description!r}>'
def __str__(self) -> str:
return self.name
class _StickerTag(Hashable, AssetMixin):
__slots__ = ()
id: int
format: StickerFormatType
async def read(self) -> bytes:
"""|coro|
Retrieves the content of this sticker as a :class:`bytes` object.
.. note::
Stickers that use the :attr:`StickerFormatType.lottie` format cannot be read.
Raises
------
HTTPException
Downloading the asset failed.
NotFound
The asset was deleted.
TypeError
The sticker is a lottie type.
Returns
-------
:class:`bytes`
The content of the asset.
"""
if self.format is StickerFormatType.lottie:
raise TypeError('Cannot read stickers of format "lottie".')
return await super().read()
class StickerItem(_StickerTag):
"""Represents a sticker item.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker item.
.. describe:: x == y
Checks if the sticker item is equal to another sticker item.
.. describe:: x != y
Checks if the sticker item is not equal to another sticker item.
Attributes
-----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
url: :class:`str`
The URL for the sticker's image.
"""
__slots__ = ('_state', 'name', 'id', 'format', 'url')
def __init__(self, *, state: ConnectionState, data: StickerItemPayload):
self._state: ConnectionState = state
self.name: str = data['name']
self.id: int = int(data['id'])
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<StickerItem id={self.id} name={self.name!r} format={self.format}>'
def __str__(self) -> str:
return self.name
async def fetch(self) -> Union[Sticker, StandardSticker, GuildSticker]:
"""|coro|
Attempts to retrieve the full sticker data of the sticker item.
Raises
--------
HTTPException
Retrieving the sticker failed.
Returns
--------
Union[:class:`StandardSticker`, :class:`GuildSticker`]
The retrieved sticker.
"""
data: StickerPayload = await self._state.http.get_sticker(self.id)
cls, _ = _sticker_factory(data['type']) # type: ignore
return cls(state=self._state, data=data)
class Sticker(_StickerTag):
"""Represents a sticker. """Represents a sticker.
.. versionadded:: 1.6 .. versionadded:: 1.6
@@ -261,263 +60,82 @@ class Sticker(_StickerTag):
The description of the sticker. The description of the sticker.
pack_id: :class:`int` pack_id: :class:`int`
The id of the sticker's pack. The id of the sticker's pack.
format: :class:`StickerFormatType` format: :class:`StickerType`
The format for the sticker's image.
url: :class:`str`
The URL for the sticker's image.
"""
__slots__ = ('_state', 'id', 'name', 'description', 'format', 'url')
def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None:
self._state: ConnectionState = state
self._from_data(data)
def _from_data(self, data: StickerPayload) -> None:
self.id: int = int(data['id'])
self.name: str = data['name']
self.description: str = data['description']
self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type'])
self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}'
def __repr__(self) -> str:
return f'<Sticker id={self.id} name={self.name!r}>'
def __str__(self) -> str:
return self.name
@property
def created_at(self) -> datetime.datetime:
""":class:`datetime.datetime`: Returns the sticker's creation time in UTC."""
return snowflake_time(self.id)
class StandardSticker(Sticker):
"""Represents a sticker that is found in a standard sticker pack.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
pack_id: :class:`int`
The id of the sticker's pack.
format: :class:`StickerFormatType`
The format for the sticker's image. The format for the sticker's image.
image: :class:`str`
The sticker's image.
tags: List[:class:`str`] tags: List[:class:`str`]
A list of tags for the sticker. A list of tags for the sticker.
sort_value: :class:`int` preview_image: Optional[:class:`str`]
The sticker's sort order within its pack. The sticker's preview asset hash.
""" """
__slots__ = ('_state', 'id', 'name', 'description', 'pack_id', 'format', 'image', 'tags', 'preview_image')
__slots__ = ('sort_value', 'pack_id', 'type', 'tags') def __init__(self, *, state, data):
self._state = state
def _from_data(self, data: StandardStickerPayload) -> None: self.id = int(data['id'])
super()._from_data(data) self.name = data['name']
self.sort_value: int = data['sort_value'] self.description = data['description']
self.pack_id: int = int(data['pack_id']) self.pack_id = int(data.get('pack_id', 0))
self.type: StickerType = StickerType.standard self.format = try_enum(StickerType, data['format_type'])
self.image = data['asset']
try: try:
self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')] self.tags = [tag.strip() for tag in data['tags'].split(',')]
except KeyError: except KeyError:
self.tags = [] self.tags = []
def __repr__(self) -> str: self.preview_image = data.get('preview_asset')
return f'<StandardSticker id={self.id} name={self.name!r} pack_id={self.pack_id}>'
async def pack(self) -> StickerPack: def __repr__(self):
"""|coro| return '<{0.__class__.__name__} id={0.id} name={0.name!r}>'.format(self)
Retrieves the sticker pack that this sticker belongs to. def __str__(self):
return self.name
Raises @property
-------- def created_at(self):
InvalidData """:class:`datetime.datetime`: Returns the sticker's creation time in UTC."""
The corresponding sticker pack was not found. return snowflake_time(self.id)
HTTPException
Retrieving the sticker pack failed. @property
def image_url(self):
"""Returns an :class:`Asset` for the sticker's image.
.. note::
This will return ``None`` if the format is ``StickerType.lottie``.
Returns Returns
-------- -------
:class:`StickerPack` Optional[:class:`Asset`]
The retrieved sticker pack. The resulting CDN asset.
""" """
data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs() return self.image_url_as()
packs = data['sticker_packs']
pack = find(lambda d: int(d['id']) == self.pack_id, packs)
if pack: def image_url_as(self, *, size=1024):
return StickerPack(state=self._state, data=pack) """Optionally returns an :class:`Asset` for the sticker's image.
raise InvalidData(f'Could not find corresponding sticker pack for {self!r}')
The size must be a power of 2 between 16 and 4096.
class GuildSticker(Sticker): .. note::
"""Represents a sticker that belongs to a guild. This will return ``None`` if the format is ``StickerType.lottie``.
.. versionadded:: 2.0
.. container:: operations
.. describe:: str(x)
Returns the name of the sticker.
.. describe:: x == y
Checks if the sticker is equal to another sticker.
.. describe:: x != y
Checks if the sticker is not equal to another sticker.
Attributes
----------
name: :class:`str`
The sticker's name.
id: :class:`int`
The id of the sticker.
description: :class:`str`
The description of the sticker.
format: :class:`StickerFormatType`
The format for the sticker's image.
available: :class:`bool`
Whether this sticker is available for use.
guild_id: :class:`int`
The ID of the guild that this sticker is from.
user: Optional[:class:`User`]
The user that created this sticker. This can only be retrieved using :meth:`Guild.fetch_sticker` and
having the :attr:`~Permissions.manage_emojis_and_stickers` permission.
emoji: :class:`str`
The name of a unicode emoji that represents this sticker.
"""
__slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild')
def _from_data(self, data: GuildStickerPayload) -> None:
super()._from_data(data)
self.available: bool = data['available']
self.guild_id: int = int(data['guild_id'])
user = data.get('user')
self.user: Optional[User] = self._state.store_user(user) if user else None
self.emoji: str = data['tags']
self.type: StickerType = StickerType.guild
def __repr__(self) -> str:
return f'<GuildSticker name={self.name!r} id={self.id} guild_id={self.guild_id} user={self.user!r}>'
@cached_slot_property('_cs_guild')
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild that this sticker is from.
Could be ``None`` if the bot is not in the guild.
.. versionadded:: 2.0
"""
return self._state._get_guild(self.guild_id)
async def edit(
self,
*,
name: str = MISSING,
description: str = MISSING,
emoji: str = MISSING,
reason: Optional[str] = None,
) -> GuildSticker:
"""|coro|
Edits a :class:`GuildSticker` for the guild.
Parameters Parameters
----------- -----------
name: :class:`str` size: :class:`int`
The sticker's new name. Must be at least 2 characters. The size of the image to display.
description: Optional[:class:`str`]
The sticker's new description. Can be ``None``.
emoji: :class:`str`
The name of a unicode emoji that represents the sticker's expression.
reason: :class:`str`
The reason for editing this sticker. Shows up on the audit log.
Raises Raises
------- ------
Forbidden InvalidArgument
You are not allowed to edit stickers. Invalid ``size``.
HTTPException
An error occurred editing the sticker.
Returns Returns
--------
:class:`GuildSticker`
The newly modified sticker.
"""
payload: EditGuildSticker = {}
if name is not MISSING:
payload['name'] = name
if description is not MISSING:
payload['description'] = description
if emoji is not MISSING:
try:
emoji = unicodedata.name(emoji)
except TypeError:
pass
else:
emoji = emoji.replace(' ', '_')
payload['tags'] = emoji
data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason)
return GuildSticker(state=self._state, data=data)
async def delete(self, *, reason: Optional[str] = None) -> None:
"""|coro|
Deletes the custom :class:`Sticker` from the guild.
You must have :attr:`~Permissions.manage_emojis_and_stickers` permission to
do this.
Parameters
-----------
reason: Optional[:class:`str`]
The reason for deleting this sticker. Shows up on the audit log.
Raises
------- -------
Forbidden Optional[:class:`Asset`]
You are not allowed to delete stickers. The resulting CDN asset or ``None``.
HTTPException
An error occurred deleting the sticker.
""" """
await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason) if self.format is StickerType.lottie:
return None
return Asset._from_sticker_url(self._state, self, size=size)
def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]:
value = try_enum(StickerType, sticker_type)
if value == StickerType.standard:
return StandardSticker, value
elif value == StickerType.guild:
return GuildSticker, value
else:
return Sticker, value

View File

@@ -22,29 +22,16 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from . import utils from . import utils
from .user import BaseUser from .user import BaseUser
from .asset import Asset from .asset import Asset
from .enums import TeamMembershipState, try_enum from .enums import TeamMembershipState, try_enum
from typing import TYPE_CHECKING, Optional, List
if TYPE_CHECKING:
from .state import ConnectionState
from .types.team import (
Team as TeamPayload,
TeamMember as TeamMemberPayload,
)
__all__ = ( __all__ = (
'Team', 'Team',
'TeamMember', 'TeamMember',
) )
class Team: class Team:
"""Represents an application team for a bot provided by Discord. """Represents an application team for a bot provided by Discord.
@@ -54,6 +41,8 @@ class Team:
The team ID. The team ID.
name: :class:`str` name: :class:`str`
The team name The team name
icon: Optional[:class:`str`]
The icon hash, if it exists.
owner_id: :class:`int` owner_id: :class:`int`
The team's owner ID. The team's owner ID.
members: List[:class:`TeamMember`] members: List[:class:`TeamMember`]
@@ -61,34 +50,61 @@ class Team:
.. versionadded:: 1.3 .. versionadded:: 1.3
""" """
__slots__ = ('_state', 'id', 'name', 'icon', 'owner_id', 'members')
__slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members') def __init__(self, state, data):
self._state = state
def __init__(self, state: ConnectionState, data: TeamPayload): self.id = utils._get_as_snowflake(data, 'id')
self._state: ConnectionState = state self.name = data['name']
self.icon = data['icon']
self.owner_id = utils._get_as_snowflake(data, 'owner_user_id')
self.members = [TeamMember(self, self._state, member) for member in data['members']]
self.id: int = int(data['id']) def __repr__(self):
self.name: str = data['name'] return '<{0.__class__.__name__} id={0.id} name={0.name}>'.format(self)
self._icon: Optional[str] = data['icon']
self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id')
self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']]
def __repr__(self) -> str:
return f'<{self.__class__.__name__} id={self.id} name={self.name}>'
@property @property
def icon(self) -> Optional[Asset]: def icon_url(self):
"""Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any.""" """:class:`.Asset`: Retrieves the team's icon asset.
if self._icon is None:
return None This is equivalent to calling :meth:`icon_url_as` with
return Asset._from_icon(self._state, self.id, self._icon, path='team') the default parameters ('webp' format and a size of 1024).
"""
return self.icon_url_as()
def icon_url_as(self, *, format='webp', size=1024):
"""Returns an :class:`Asset` for the icon the team has.
The format must be one of 'webp', 'jpeg', 'jpg' or 'png'.
The size must be a power of 2 between 16 and 4096.
.. versionadded:: 2.0
Parameters
-----------
format: :class:`str`
The format to attempt to convert the icon to. Defaults to 'webp'.
size: :class:`int`
The size of the image to display.
Raises
------
InvalidArgument
Bad image format passed to ``format`` or invalid ``size``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
"""
return Asset._from_icon(self._state, self, 'team', format=format, size=size)
@property @property
def owner(self) -> Optional[TeamMember]: def owner(self):
"""Optional[:class:`TeamMember`]: The team's owner.""" """Optional[:class:`TeamMember`]: The team's owner."""
return utils.get(self.members, id=self.owner_id) return utils.get(self.members, id=self.owner_id)
class TeamMember(BaseUser): class TeamMember(BaseUser):
"""Represents a team member in a team. """Represents a team member in a team.
@@ -129,17 +145,14 @@ class TeamMember(BaseUser):
membership_state: :class:`TeamMembershipState` membership_state: :class:`TeamMembershipState`
The membership state of the member (e.g. invited or accepted) The membership state of the member (e.g. invited or accepted)
""" """
__slots__ = BaseUser.__slots__ + ('team', 'membership_state', 'permissions')
__slots__ = ('team', 'membership_state', 'permissions') def __init__(self, team, state, data):
self.team = team
def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload): self.membership_state = try_enum(TeamMembershipState, data['membership_state'])
self.team: Team = team self.permissions = data['permissions']
self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state'])
self.permissions: List[str] = data['permissions']
super().__init__(state=state, data=data['user']) super().__init__(state=state, data=data['user'])
def __repr__(self) -> str: def __repr__(self):
return ( return '<{0.__class__.__name__} id={0.id} name={0.name!r} ' \
f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' 'discriminator={0.discriminator!r} membership_state={0.membership_state!r}>'.format(self)
f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>'
)

View File

@@ -22,10 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data
from typing import Any, Optional, TYPE_CHECKING
from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING
from .enums import VoiceRegion from .enums import VoiceRegion
from .guild import Guild from .guild import Guild
@@ -33,20 +30,12 @@ __all__ = (
'Template', 'Template',
) )
if TYPE_CHECKING:
import datetime
from .types.template import Template as TemplatePayload
from .state import ConnectionState
from .user import User
class _FriendlyHttpAttributeErrorHelper: class _FriendlyHttpAttributeErrorHelper:
__slots__ = () __slots__ = ()
def __getattr__(self, attr): def __getattr__(self, attr):
raise AttributeError('PartialTemplateState does not support http methods.') raise AttributeError('PartialTemplateState does not support http methods.')
class _PartialTemplateState: class _PartialTemplateState:
def __init__(self, *, state): def __init__(self, *, state):
self.__state = state self.__state = state
@@ -77,16 +66,12 @@ class _PartialTemplateState:
def _get_message(self, id): def _get_message(self, id):
return None return None
def _get_guild(self, id): async def query_members(self, **kwargs):
return self.__state._get_guild(id)
async def query_members(self, **kwargs: Any):
return [] return []
def __getattr__(self, attr): def __getattr__(self, attr):
raise AttributeError(f'PartialTemplateState does not support {attr!r}.') raise AttributeError(f'PartialTemplateState does not support {attr!r}.')
class Template: class Template:
"""Represents a Discord template. """Represents a Discord template.
@@ -111,62 +96,40 @@ class Template:
This is referred to as "last synced" in the official Discord client. This is referred to as "last synced" in the official Discord client.
source_guild: :class:`Guild` source_guild: :class:`Guild`
The source guild. The source guild.
is_dirty: Optional[:class:`bool`]
Whether the template has unsynced changes.
.. versionadded:: 2.0
""" """
__slots__ = ( def __init__(self, *, state, data):
'code',
'uses',
'name',
'description',
'creator',
'created_at',
'updated_at',
'source_guild',
'is_dirty',
'_state',
)
def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None:
self._state = state self._state = state
self._store(data) self._store(data)
def _store(self, data: TemplatePayload) -> None: def _store(self, data):
self.code: str = data['code'] self.code = data['code']
self.uses: int = data['usage_count'] self.uses = data['usage_count']
self.name: str = data['name'] self.name = data['name']
self.description: Optional[str] = data['description'] self.description = data['description']
creator_data = data.get('creator') creator_data = data.get('creator')
self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data) self.creator = None if creator_data is None else self._state.store_user(creator_data)
self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) self.created_at = parse_time(data.get('created_at'))
self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at')) self.updated_at = parse_time(data.get('updated_at'))
guild_id = int(data['source_guild_id']) id = _get_as_snowflake(data, 'source_guild_id')
guild: Optional[Guild] = self._state._get_guild(guild_id)
guild = self._state._get_guild(id)
self.source_guild: Guild
if guild is None: if guild is None:
source_serialised = data['serialized_source_guild'] source_serialised = data['serialized_source_guild']
source_serialised['id'] = guild_id source_serialised['id'] = id
state = _PartialTemplateState(state=self._state) state = _PartialTemplateState(state=self._state)
# Guild expects a ConnectionState, we're passing a _PartialTemplateState guild = Guild(data=source_serialised, state=state)
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
else:
self.source_guild = guild
self.is_dirty: Optional[bool] = data.get('is_dirty', None) self.source_guild = guild
def __repr__(self) -> str: def __repr__(self):
return ( return '<Template code={0.code!r} uses={0.uses} name={0.name!r}' \
f'<Template code={self.code!r} uses={self.uses} name={self.name!r}' ' creator={0.creator!r} source_guild={0.source_guild!r}>'.format(self)
f' creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>'
)
async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild: async def create_guild(self, name, region=None, icon=None):
"""|coro| """|coro|
Creates a :class:`.Guild` using the template. Creates a :class:`.Guild` using the template.
@@ -206,7 +169,7 @@ class Template:
data = await self._state.http.create_from_template(self.code, name, region_value, icon) data = await self._state.http.create_from_template(self.code, name, region_value, icon)
return Guild(data=data, state=self._state) return Guild(data=data, state=self._state)
async def sync(self) -> Template: async def sync(self):
"""|coro| """|coro|
Sync the template to the guild's current state. Sync the template to the guild's current state.
@@ -216,9 +179,6 @@ class Template:
.. versionadded:: 1.7 .. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Raises Raises
------- -------
HTTPException HTTPException
@@ -227,22 +187,12 @@ class Template:
You don't have permissions to edit the template. You don't have permissions to edit the template.
NotFound NotFound
This template does not exist. This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
""" """
data = await self._state.http.sync_template(self.source_guild.id, self.code) data = await self._state.http.sync_template(self.source_guild.id, self.code)
return Template(state=self._state, data=data) self._store(data)
async def edit( async def edit(self, **kwargs):
self,
*,
name: str = MISSING,
description: Optional[str] = MISSING,
) -> Template:
"""|coro| """|coro|
Edit the template metadata. Edit the template metadata.
@@ -252,15 +202,12 @@ class Template:
.. versionadded:: 1.7 .. versionadded:: 1.7
.. versionchanged:: 2.0
The template is no longer edited in-place, instead it is returned.
Parameters Parameters
------------ ------------
name: :class:`str` name: Optional[:class:`str`]
The template's new name. The template's new name.
description: Optional[:class:`str`] description: Optional[:class:`str`]
The template's new description. The template's description.
Raises Raises
------- -------
@@ -270,23 +217,11 @@ class Template:
You don't have permissions to edit the template. You don't have permissions to edit the template.
NotFound NotFound
This template does not exist. This template does not exist.
Returns
--------
:class:`Template`
The newly edited template.
""" """
payload = {} data = await self._state.http.edit_template(self.source_guild.id, self.code, kwargs)
self._store(data)
if name is not MISSING: async def delete(self):
payload['name'] = name
if description is not MISSING:
payload['description'] = description
data = await self._state.http.edit_template(self.source_guild.id, self.code, payload)
return Template(state=self._state, data=data)
async def delete(self) -> None:
"""|coro| """|coro|
Delete the template. Delete the template.
@@ -306,11 +241,3 @@ class Template:
This template does not exist. This template does not exist.
""" """
await self._state.http.delete_template(self.source_guild.id, self.code) await self._state.http.delete_template(self.source_guild.id, self.code)
@property
def url(self) -> str:
""":class:`str`: The template url.
.. versionadded:: 2.0
"""
return f'https://discord.new/{self.code}'

View File

@@ -1,802 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Callable, Dict, Iterable, List, Optional, Union, TYPE_CHECKING
import time
import asyncio
from .mixins import Hashable
from .abc import Messageable
from .enums import ChannelType, try_enum
from .errors import ClientException
from .utils import MISSING, parse_time, _get_as_snowflake
__all__ = (
'Thread',
'ThreadMember',
)
if TYPE_CHECKING:
from .types.threads import (
Thread as ThreadPayload,
ThreadMember as ThreadMemberPayload,
ThreadMetadata,
ThreadArchiveDuration,
)
from .types.snowflake import SnowflakeList
from .guild import Guild
from .channel import TextChannel, CategoryChannel
from .member import Member
from .message import Message, PartialMessage
from .abc import Snowflake, SnowflakeTime
from .role import Role
from .permissions import Permissions
from .state import ConnectionState
class Thread(Messageable, Hashable):
"""Represents a Discord thread.
.. container:: operations
.. describe:: x == y
Checks if two threads are equal.
.. describe:: x != y
Checks if two threads are not equal.
.. describe:: hash(x)
Returns the thread's hash.
.. describe:: str(x)
Returns the thread's name.
.. versionadded:: 2.0
Attributes
-----------
name: :class:`str`
The thread name.
guild: :class:`Guild`
The guild the thread belongs to.
id: :class:`int`
The thread ID.
parent_id: :class:`int`
The parent :class:`TextChannel` ID this thread belongs to.
owner_id: :class:`int`
The user's ID that created this thread.
last_message_id: Optional[:class:`int`]
The last message ID of the message sent to this thread. It may
*not* point to an existing or valid message.
slowmode_delay: :class:`int`
The number of seconds a member must wait between sending messages
in this thread. A value of `0` denotes that it is disabled.
Bots and users with :attr:`~Permissions.manage_channels` or
:attr:`~Permissions.manage_messages` bypass slowmode.
message_count: :class:`int`
An approximate number of messages in this thread. This caps at 50.
member_count: :class:`int`
An approximate number of members in this thread. This caps at 50.
me: Optional[:class:`ThreadMember`]
A thread member representing yourself, if you've joined the thread.
This could not be available.
archived: :class:`bool`
Whether the thread is archived.
locked: :class:`bool`
Whether the thread is locked.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
This is always ``True`` for public threads.
archiver_id: Optional[:class:`int`]
The user's ID that archived this thread.
auto_archive_duration: :class:`int`
The duration in minutes until the thread is automatically archived due to inactivity.
Usually a value of 60, 1440, 4320 and 10080.
archive_timestamp: :class:`datetime.datetime`
An aware timestamp of when the thread's archived status was last updated in UTC.
"""
__slots__ = (
'name',
'id',
'guild',
'_type',
'_state',
'_members',
'owner_id',
'parent_id',
'last_message_id',
'message_count',
'member_count',
'slowmode_delay',
'me',
'locked',
'archived',
'invitable',
'archiver_id',
'auto_archive_duration',
'archive_timestamp',
)
def __init__(self, *, guild: Guild, state: ConnectionState, data: ThreadPayload):
self._state: ConnectionState = state
self.guild = guild
self._members: Dict[int, ThreadMember] = {}
self._from_data(data)
async def _get_channel(self):
return self
def __repr__(self) -> str:
return (
f'<Thread id={self.id!r} name={self.name!r} parent={self.parent}'
f' owner_id={self.owner_id!r} locked={self.locked} archived={self.archived}>'
)
def __str__(self) -> str:
return self.name
def _from_data(self, data: ThreadPayload):
self.id = int(data['id'])
self.parent_id = int(data['parent_id'])
self.owner_id = int(data['owner_id'])
self.name = data['name']
self._type = try_enum(ChannelType, data['type'])
self.last_message_id = _get_as_snowflake(data, 'last_message_id')
self.slowmode_delay = data.get('rate_limit_per_user', 0)
self.message_count = data['message_count']
self.member_count = data['member_count']
self._unroll_metadata(data['thread_metadata'])
try:
member = data['member']
except KeyError:
self.me = None
else:
self.me = ThreadMember(self, member)
def _unroll_metadata(self, data: ThreadMetadata):
self.archived = data['archived']
self.archiver_id = _get_as_snowflake(data, 'archiver_id')
self.auto_archive_duration = data['auto_archive_duration']
self.archive_timestamp = parse_time(data['archive_timestamp'])
self.locked = data.get('locked', False)
self.invitable = data.get('invitable', True)
def _update(self, data):
try:
self.name = data['name']
except KeyError:
pass
self.slowmode_delay = data.get('rate_limit_per_user', 0)
try:
self._unroll_metadata(data['thread_metadata'])
except KeyError:
pass
@property
def type(self) -> ChannelType:
""":class:`ChannelType`: The channel's Discord type."""
return self._type
@property
def parent(self) -> Optional[TextChannel]:
"""Optional[:class:`TextChannel`]: The parent channel this thread belongs to."""
return self.guild.get_channel(self.parent_id) # type: ignore
@property
def owner(self) -> Optional[Member]:
"""Optional[:class:`Member`]: The member this thread belongs to."""
return self.guild.get_member(self.owner_id)
@property
def mention(self) -> str:
""":class:`str`: The string that allows you to mention the thread."""
return f'<#{self.id}>'
@property
def members(self) -> List[ThreadMember]:
"""List[:class:`ThreadMember`]: A list of thread members in this thread.
This requires :attr:`Intents.members` to be properly filled. Most of the time however,
this data is not provided by the gateway and a call to :meth:`fetch_members` is
needed.
"""
return list(self._members.values())
@property
def last_message(self) -> Optional[Message]:
"""Fetches the last message from this channel in cache.
The message might not be valid or point to an existing message.
.. admonition:: Reliable Fetching
:class: helpful
For a slightly more reliable method of fetching the
last message, consider using either :meth:`history`
or :meth:`fetch_message` with the :attr:`last_message_id`
attribute.
Returns
---------
Optional[:class:`Message`]
The last message in this channel or ``None`` if not found.
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None
@property
def category(self) -> Optional[CategoryChannel]:
"""The category channel the parent channel belongs to, if applicable.
Raises
-------
ClientException
The parent channel was not cached and returned ``None``.
Returns
-------
Optional[:class:`CategoryChannel`]
The parent channel's category.
"""
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
return parent.category
@property
def category_id(self) -> Optional[int]:
"""The category channel ID the parent channel belongs to, if applicable.
Raises
-------
ClientException
The parent channel was not cached and returned ``None``.
Returns
-------
Optional[:class:`int`]
The parent channel's category ID.
"""
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
return parent.category_id
def is_private(self) -> bool:
""":class:`bool`: Whether the thread is a private thread.
A private thread is only viewable by those that have been explicitly
invited or have :attr:`~.Permissions.manage_threads`.
"""
return self._type is ChannelType.private_thread
def is_news(self) -> bool:
""":class:`bool`: Whether the thread is a news thread.
A news thread is a thread that has a parent that is a news channel,
i.e. :meth:`.TextChannel.is_news` is ``True``.
"""
return self._type is ChannelType.news_thread
def is_nsfw(self) -> bool:
""":class:`bool`: Whether the thread is NSFW or not.
An NSFW thread is a thread that has a parent that is an NSFW channel,
i.e. :meth:`.TextChannel.is_nsfw` is ``True``.
"""
parent = self.parent
return parent is not None and parent.is_nsfw()
def permissions_for(self, obj: Union[Member, Role], /) -> Permissions:
"""Handles permission resolution for the :class:`~discord.Member`
or :class:`~discord.Role`.
Since threads do not have their own permissions, they inherit them
from the parent channel. This is a convenience method for
calling :meth:`~discord.TextChannel.permissions_for` on the
parent channel.
Parameters
----------
obj: Union[:class:`~discord.Member`, :class:`~discord.Role`]
The object to resolve permissions for. This could be either
a member or a role. If it's a role then member overwrites
are not computed.
Raises
-------
ClientException
The parent channel was not cached and returned ``None``
Returns
-------
:class:`~discord.Permissions`
The resolved permissions for the member or role.
"""
parent = self.parent
if parent is None:
raise ClientException('Parent channel not found')
return parent.permissions_for(obj)
async def delete_messages(self, messages: Iterable[Snowflake]) -> None:
"""|coro|
Deletes a list of messages. This is similar to :meth:`Message.delete`
except it bulk deletes multiple messages.
As a special case, if the number of messages is 0, then nothing
is done. If the number of messages is 1 then single message
delete is done. If it's more than two, then bulk delete is used.
You cannot bulk delete more than 100 messages or messages that
are older than 14 days old.
You must have the :attr:`~Permissions.manage_messages` permission to
use this.
Usable only by bot accounts.
Parameters
-----------
messages: Iterable[:class:`abc.Snowflake`]
An iterable of messages denoting which ones to bulk delete.
Raises
------
ClientException
The number of messages to delete was more than 100.
Forbidden
You do not have proper permissions to delete the messages or
you're not using a bot account.
NotFound
If single delete, then the message was already deleted.
HTTPException
Deleting the messages failed.
"""
if not isinstance(messages, (list, tuple)):
messages = list(messages)
if len(messages) == 0:
return # do nothing
if len(messages) == 1:
message_id = messages[0].id
await self._state.http.delete_message(self.id, message_id)
return
if len(messages) > 100:
raise ClientException('Can only bulk delete messages up to 100 messages')
message_ids: SnowflakeList = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)
async def purge(
self,
*,
limit: Optional[int] = 100,
check: Callable[[Message], bool] = MISSING,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = False,
bulk: bool = True,
) -> List[Message]:
"""|coro|
Purges a list of messages that meet the criteria given by the predicate
``check``. If a ``check`` is not provided then all messages are deleted
without discrimination.
You must have the :attr:`~Permissions.manage_messages` permission to
delete messages even if they are your own (unless you are a user
account). The :attr:`~Permissions.read_message_history` permission is
also needed to retrieve message history.
Examples
---------
Deleting bot's messages ::
def is_me(m):
return m.author == client.user
deleted = await thread.purge(limit=100, check=is_me)
await thread.send(f'Deleted {len(deleted)} message(s)')
Parameters
-----------
limit: Optional[:class:`int`]
The number of messages to search through. This is not the number
of messages that will be deleted, though it can be.
check: Callable[[:class:`Message`], :class:`bool`]
The function used to check if a message should be deleted.
It must take a :class:`Message` as its sole parameter.
before: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Same as ``before`` in :meth:`history`.
after: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Same as ``after`` in :meth:`history`.
around: Optional[Union[:class:`abc.Snowflake`, :class:`datetime.datetime`]]
Same as ``around`` in :meth:`history`.
oldest_first: Optional[:class:`bool`]
Same as ``oldest_first`` in :meth:`history`.
bulk: :class:`bool`
If ``True``, use bulk delete. Setting this to ``False`` is useful for mass-deleting
a bot's own messages without :attr:`Permissions.manage_messages`. When ``True``, will
fall back to single delete if messages are older than two weeks.
Raises
-------
Forbidden
You do not have proper permissions to do the actions required.
HTTPException
Purging the messages failed.
Returns
--------
List[:class:`.Message`]
The list of messages that were deleted.
"""
if check is MISSING:
check = lambda m: True
iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around)
ret: List[Message] = []
count = 0
minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22
async def _single_delete_strategy(messages: Iterable[Message]):
for m in messages:
await m.delete()
strategy = self.delete_messages if bulk else _single_delete_strategy
async for message in iterator:
if count == 100:
to_delete = ret[-100:]
await strategy(to_delete)
count = 0
await asyncio.sleep(1)
if not check(message):
continue
if message.id < minimum_time:
# older than 14 days old
if count == 1:
await ret[-1].delete()
elif count >= 2:
to_delete = ret[-count:]
await strategy(to_delete)
count = 0
strategy = _single_delete_strategy
count += 1
ret.append(message)
# SOme messages remaining to poll
if count >= 2:
# more than 2 messages -> bulk delete
to_delete = ret[-count:]
await strategy(to_delete)
elif count == 1:
# delete a single message
await ret[-1].delete()
return ret
async def edit(
self,
*,
name: str = MISSING,
archived: bool = MISSING,
locked: bool = MISSING,
invitable: bool = MISSING,
slowmode_delay: int = MISSING,
auto_archive_duration: ThreadArchiveDuration = MISSING,
) -> Thread:
"""|coro|
Edits the thread.
Editing the thread requires :attr:`.Permissions.manage_threads`. The thread
creator can also edit ``name``, ``archived`` or ``auto_archive_duration``.
Note that if the thread is locked then only those with :attr:`.Permissions.manage_threads`
can unarchive a thread.
The thread must be unarchived to be edited.
Parameters
------------
name: :class:`str`
The new name of the thread.
archived: :class:`bool`
Whether to archive the thread or not.
locked: :class:`bool`
Whether to lock the thread or not.
invitable: :class:`bool`
Whether non-moderators can add other non-moderators to this thread.
Only available for private threads.
auto_archive_duration: :class:`int`
The new duration in minutes before a thread is automatically archived for inactivity.
Must be one of ``60``, ``1440``, ``4320``, or ``10080``.
slowmode_delay: :class:`int`
Specifies the slowmode rate limit for user in this thread, in seconds.
A value of ``0`` disables slowmode. The maximum value possible is ``21600``.
Raises
-------
Forbidden
You do not have permissions to edit the thread.
HTTPException
Editing the thread failed.
Returns
--------
:class:`Thread`
The newly edited thread.
"""
payload = {}
if name is not MISSING:
payload['name'] = str(name)
if archived is not MISSING:
payload['archived'] = archived
if auto_archive_duration is not MISSING:
payload['auto_archive_duration'] = auto_archive_duration
if locked is not MISSING:
payload['locked'] = locked
if invitable is not MISSING:
payload['invitable'] = invitable
if slowmode_delay is not MISSING:
payload['rate_limit_per_user'] = slowmode_delay
data = await self._state.http.edit_channel(self.id, **payload)
# The data payload will always be a Thread payload
return Thread(data=data, state=self._state, guild=self.guild) # type: ignore
async def join(self):
"""|coro|
Joins this thread.
You must have :attr:`~Permissions.send_messages_in_threads` to join a thread.
If the thread is private, :attr:`~Permissions.manage_threads` is also needed.
Raises
-------
Forbidden
You do not have permissions to join the thread.
HTTPException
Joining the thread failed.
"""
await self._state.http.join_thread(self.id)
async def leave(self):
"""|coro|
Leaves this thread.
Raises
-------
HTTPException
Leaving the thread failed.
"""
await self._state.http.leave_thread(self.id)
async def add_user(self, user: Snowflake):
"""|coro|
Adds a user to this thread.
You must have :attr:`~Permissions.send_messages` and :attr:`~Permissions.use_threads`
to add a user to a public thread. If the thread is private then :attr:`~Permissions.send_messages`
and either :attr:`~Permissions.use_private_threads` or :attr:`~Permissions.manage_messages`
is required to add a user to the thread.
Parameters
-----------
user: :class:`abc.Snowflake`
The user to add to the thread.
Raises
-------
Forbidden
You do not have permissions to add the user to the thread.
HTTPException
Adding the user to the thread failed.
"""
await self._state.http.add_user_to_thread(self.id, user.id)
async def remove_user(self, user: Snowflake):
"""|coro|
Removes a user from this thread.
You must have :attr:`~Permissions.manage_threads` or be the creator of the thread to remove a user.
Parameters
-----------
user: :class:`abc.Snowflake`
The user to add to the thread.
Raises
-------
Forbidden
You do not have permissions to remove the user from the thread.
HTTPException
Removing the user from the thread failed.
"""
await self._state.http.remove_user_from_thread(self.id, user.id)
async def fetch_members(self) -> List[ThreadMember]:
"""|coro|
Retrieves all :class:`ThreadMember` that are in this thread.
This requires :attr:`Intents.members` to get information about members
other than yourself.
Raises
-------
HTTPException
Retrieving the members failed.
Returns
--------
List[:class:`ThreadMember`]
All thread members in the thread.
"""
members = await self._state.http.get_thread_members(self.id)
return [ThreadMember(parent=self, data=data) for data in members]
async def delete(self):
"""|coro|
Deletes this thread.
You must have :attr:`~Permissions.manage_threads` to delete threads.
Raises
-------
Forbidden
You do not have permissions to delete this thread.
HTTPException
Deleting the thread failed.
"""
await self._state.http.delete_channel(self.id)
def get_partial_message(self, message_id: int, /) -> PartialMessage:
"""Creates a :class:`PartialMessage` from the message ID.
This is useful if you want to work with a message and only have its ID without
doing an unnecessary API call.
.. versionadded:: 2.0
Parameters
------------
message_id: :class:`int`
The message ID to create a partial message for.
Returns
---------
:class:`PartialMessage`
The partial message.
"""
from .message import PartialMessage
return PartialMessage(channel=self, id=message_id)
def _add_member(self, member: ThreadMember) -> None:
self._members[member.id] = member
def _pop_member(self, member_id: int) -> Optional[ThreadMember]:
return self._members.pop(member_id, None)
class ThreadMember(Hashable):
"""Represents a Discord thread member.
.. container:: operations
.. describe:: x == y
Checks if two thread members are equal.
.. describe:: x != y
Checks if two thread members are not equal.
.. describe:: hash(x)
Returns the thread member's hash.
.. describe:: str(x)
Returns the thread member's name.
.. versionadded:: 2.0
Attributes
-----------
id: :class:`int`
The thread member's ID.
thread_id: :class:`int`
The thread's ID.
joined_at: :class:`datetime.datetime`
The time the member joined the thread in UTC.
"""
__slots__ = (
'id',
'thread_id',
'joined_at',
'flags',
'_state',
'parent',
)
def __init__(self, parent: Thread, data: ThreadMemberPayload):
self.parent = parent
self._state = parent._state
self._from_data(data)
def __repr__(self) -> str:
return f'<ThreadMember id={self.id} thread_id={self.thread_id} joined_at={self.joined_at!r}>'
def _from_data(self, data: ThreadMemberPayload):
try:
self.id = int(data['user_id'])
except KeyError:
assert self._state.self_id is not None
self.id = self._state.self_id
try:
self.thread_id = int(data['id'])
except KeyError:
self.thread_id = self.parent.id
self.joined_at = parse_time(data['join_timestamp'])
self.flags = data['flags']
@property
def thread(self) -> Thread:
""":class:`Thread`: The thread this member belongs to."""
return self.parent

View File

@@ -1,114 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict
from .user import PartialUser
from .snowflake import Snowflake
StatusType = Literal['idle', 'dnd', 'online', 'offline']
class PartialPresenceUpdate(TypedDict):
user: PartialUser
guild_id: Snowflake
status: StatusType
activities: List[Activity]
client_status: ClientStatus
class ClientStatus(TypedDict, total=False):
desktop: str
mobile: str
web: str
class ActivityTimestamps(TypedDict, total=False):
start: int
end: int
class ActivityParty(TypedDict, total=False):
id: str
size: List[int]
class ActivityAssets(TypedDict, total=False):
large_image: str
large_text: str
small_image: str
small_text: str
class ActivitySecrets(TypedDict, total=False):
join: str
spectate: str
match: str
class _ActivityEmojiOptional(TypedDict, total=False):
id: Snowflake
animated: bool
class ActivityEmoji(_ActivityEmojiOptional):
name: str
class ActivityButton(TypedDict):
label: str
url: str
class _SendableActivityOptional(TypedDict, total=False):
url: Optional[str]
ActivityType = Literal[0, 1, 2, 4, 5]
class SendableActivity(_SendableActivityOptional):
name: str
type: ActivityType
class _BaseActivity(SendableActivity):
created_at: int
class Activity(_BaseActivity, total=False):
state: Optional[str]
details: Optional[str]
timestamps: ActivityTimestamps
assets: ActivityAssets
party: ActivityParty
application_id: Snowflake
flags: int
emoji: Optional[ActivityEmoji]
secrets: ActivitySecrets
session_id: Optional[str]
instance: bool
buttons: List[ActivityButton]

View File

@@ -1,67 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TypedDict, List, Optional
from .user import User
from .team import Team
from .snowflake import Snowflake
class BaseAppInfo(TypedDict):
id: Snowflake
name: str
verify_key: str
icon: Optional[str]
summary: str
description: str
class _AppInfoOptional(TypedDict, total=False):
team: Team
guild_id: Snowflake
primary_sku_id: Snowflake
slug: str
terms_of_service_url: str
privacy_policy_url: str
hook: bool
max_participants: int
class AppInfo(BaseAppInfo, _AppInfoOptional):
rpc_origins: List[str]
owner: User
bot_public: bool
bot_require_code_grant: bool
class _PartialAppInfoOptional(TypedDict, total=False):
rpc_origins: List[str]
cover_image: str
hook: bool
terms_of_service_url: str
privacy_policy_url: str
max_participants: int
flags: int
class PartialAppInfo(_PartialAppInfoOptional, BaseAppInfo):
pass

View File

@@ -1,257 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
from .webhook import Webhook
from .guild import MFALevel, VerificationLevel, ExplicitContentFilterLevel, DefaultMessageNotificationLevel
from .integration import IntegrationExpireBehavior, PartialIntegration
from .user import User
from .snowflake import Snowflake
from .role import Role
from .channel import ChannelType, VideoQualityMode, PermissionOverwrite
from .threads import Thread
AuditLogEvent = Literal[
1,
10,
11,
12,
13,
14,
15,
20,
21,
22,
23,
24,
25,
26,
27,
28,
30,
31,
32,
40,
41,
42,
50,
51,
52,
60,
61,
62,
72,
73,
74,
75,
80,
81,
82,
83,
84,
85,
90,
91,
92,
110,
111,
112,
]
class _AuditLogChange_Str(TypedDict):
key: Literal[
'name', 'description', 'preferred_locale', 'vanity_url_code', 'topic', 'code', 'allow', 'deny', 'permissions', 'tags'
]
new_value: str
old_value: str
class _AuditLogChange_AssetHash(TypedDict):
key: Literal['icon_hash', 'splash_hash', 'discovery_splash_hash', 'banner_hash', 'avatar_hash', 'asset']
new_value: str
old_value: str
class _AuditLogChange_Snowflake(TypedDict):
key: Literal[
'id',
'owner_id',
'afk_channel_id',
'rules_channel_id',
'public_updates_channel_id',
'widget_channel_id',
'system_channel_id',
'application_id',
'channel_id',
'inviter_id',
'guild_id',
]
new_value: Snowflake
old_value: Snowflake
class _AuditLogChange_Bool(TypedDict):
key: Literal[
'widget_enabled',
'nsfw',
'hoist',
'mentionable',
'temporary',
'deaf',
'mute',
'nick',
'enabled_emoticons',
'region',
'rtc_region',
'available',
'archived',
'locked',
]
new_value: bool
old_value: bool
class _AuditLogChange_Int(TypedDict):
key: Literal[
'afk_timeout',
'prune_delete_days',
'position',
'bitrate',
'rate_limit_per_user',
'color',
'max_uses',
'max_age',
'user_limit',
'auto_archive_duration',
'default_auto_archive_duration',
]
new_value: int
old_value: int
class _AuditLogChange_ListRole(TypedDict):
key: Literal['$add', '$remove']
new_value: List[Role]
old_value: List[Role]
class _AuditLogChange_MFALevel(TypedDict):
key: Literal['mfa_level']
new_value: MFALevel
old_value: MFALevel
class _AuditLogChange_VerificationLevel(TypedDict):
key: Literal['verification_level']
new_value: VerificationLevel
old_value: VerificationLevel
class _AuditLogChange_ExplicitContentFilter(TypedDict):
key: Literal['explicit_content_filter']
new_value: ExplicitContentFilterLevel
old_value: ExplicitContentFilterLevel
class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict):
key: Literal['default_message_notifications']
new_value: DefaultMessageNotificationLevel
old_value: DefaultMessageNotificationLevel
class _AuditLogChange_ChannelType(TypedDict):
key: Literal['type']
new_value: ChannelType
old_value: ChannelType
class _AuditLogChange_IntegrationExpireBehaviour(TypedDict):
key: Literal['expire_behavior']
new_value: IntegrationExpireBehavior
old_value: IntegrationExpireBehavior
class _AuditLogChange_VideoQualityMode(TypedDict):
key: Literal['video_quality_mode']
new_value: VideoQualityMode
old_value: VideoQualityMode
class _AuditLogChange_Overwrites(TypedDict):
key: Literal['permission_overwrites']
new_value: List[PermissionOverwrite]
old_value: List[PermissionOverwrite]
AuditLogChange = Union[
_AuditLogChange_Str,
_AuditLogChange_AssetHash,
_AuditLogChange_Snowflake,
_AuditLogChange_Int,
_AuditLogChange_Bool,
_AuditLogChange_ListRole,
_AuditLogChange_MFALevel,
_AuditLogChange_VerificationLevel,
_AuditLogChange_ExplicitContentFilter,
_AuditLogChange_DefaultMessageNotificationLevel,
_AuditLogChange_ChannelType,
_AuditLogChange_IntegrationExpireBehaviour,
_AuditLogChange_VideoQualityMode,
_AuditLogChange_Overwrites,
]
class AuditEntryInfo(TypedDict):
delete_member_days: str
members_removed: str
channel_id: Snowflake
message_id: Snowflake
count: str
id: Snowflake
type: Literal['0', '1']
role_name: str
class _AuditLogEntryOptional(TypedDict, total=False):
changes: List[AuditLogChange]
options: AuditEntryInfo
reason: str
class AuditLogEntry(_AuditLogEntryOptional):
target_id: Optional[str]
user_id: Optional[Snowflake]
id: Snowflake
action_type: AuditLogEvent
class AuditLog(TypedDict):
webhooks: List[Webhook]
users: List[User]
audit_log_entries: List[AuditLogEntry]
integrations: List[PartialIntegration]
threads: List[Thread]

View File

@@ -22,31 +22,58 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from typing import List, Literal, Optional, TypedDict, Union
from .user import PartialUser from .user import PartialUser
from .snowflake import Snowflake from .snowflake import Snowflake
from .threads import ThreadMetadata, ThreadMember, ThreadArchiveDuration from typing import List, Literal, Optional, TypedDict
OverwriteType = Literal[0, 1]
class PermissionOverwrite(TypedDict): class PermissionOverwrite(TypedDict):
id: Snowflake id: Snowflake
type: OverwriteType type: Literal[0, 1]
allow: str allow: str
deny: str deny: str
ChannelType = Literal[0, 1, 2, 3, 4, 5, 6, 10, 11, 12, 13] ChannelType = Literal[0, 1, 2, 3, 4, 5, 6, 13]
class _BaseChannel(TypedDict): class PartialChannel(TypedDict):
id: Snowflake id: str
type: ChannelType
name: str name: str
class _BaseGuildChannel(_BaseChannel): class _TextChannelOptional(PartialChannel, total=False):
topic: str
last_message_id: Optional[Snowflake]
last_pin_timestamp: int
rate_limit_per_user: int
class _VoiceChannelOptional(PartialChannel, total=False):
rtc_region: Optional[str]
bitrate: int
user_limit: int
class _CategoryChannelOptional(PartialChannel, total=False):
...
class _StoreChannelOptional(PartialChannel, total=False):
...
class _StageChannelOptional(PartialChannel, total=False):
rtc_region: Optional[str]
bitrate: int
user_limit: int
topic: str
class GuildChannel(
_TextChannelOptional, _VoiceChannelOptional, _CategoryChannelOptional, _StoreChannelOptional, _StageChannelOptional
):
guild_id: Snowflake guild_id: Snowflake
position: int position: int
permission_overwrites: List[PermissionOverwrite] permission_overwrites: List[PermissionOverwrite]
@@ -54,104 +81,11 @@ class _BaseGuildChannel(_BaseChannel):
parent_id: Optional[Snowflake] parent_id: Optional[Snowflake]
class PartialChannel(_BaseChannel): class DMChannel(PartialChannel):
type: ChannelType
class _TextChannelOptional(TypedDict, total=False):
topic: str
last_message_id: Optional[Snowflake]
last_pin_timestamp: str
rate_limit_per_user: int
default_auto_archive_duration: ThreadArchiveDuration
class TextChannel(_BaseGuildChannel, _TextChannelOptional):
type: Literal[0]
class NewsChannel(_BaseGuildChannel, _TextChannelOptional):
type: Literal[5]
VideoQualityMode = Literal[1, 2]
class _VoiceChannelOptional(TypedDict, total=False):
rtc_region: Optional[str]
video_quality_mode: VideoQualityMode
class VoiceChannel(_BaseGuildChannel, _VoiceChannelOptional):
type: Literal[2]
bitrate: int
user_limit: int
class CategoryChannel(_BaseGuildChannel):
type: Literal[4]
class StoreChannel(_BaseGuildChannel):
type: Literal[6]
class _StageChannelOptional(TypedDict, total=False):
rtc_region: Optional[str]
topic: str
class StageChannel(_BaseGuildChannel, _StageChannelOptional):
type: Literal[13]
bitrate: int
user_limit: int
class _ThreadChannelOptional(TypedDict, total=False):
member: ThreadMember
owner_id: Snowflake
rate_limit_per_user: int
last_message_id: Optional[Snowflake]
last_pin_timestamp: str
class ThreadChannel(_BaseChannel, _ThreadChannelOptional):
type: Literal[10, 11, 12]
guild_id: Snowflake
parent_id: Snowflake
owner_id: Snowflake
nsfw: bool
last_message_id: Optional[Snowflake]
rate_limit_per_user: int
message_count: int
member_count: int
thread_metadata: ThreadMetadata
GuildChannel = Union[TextChannel, NewsChannel, VoiceChannel, CategoryChannel, StoreChannel, StageChannel, ThreadChannel]
class DMChannel(_BaseChannel):
type: Literal[1]
last_message_id: Optional[Snowflake] last_message_id: Optional[Snowflake]
recipients: List[PartialUser] recipients: List[PartialUser]
class GroupDMChannel(_BaseChannel): class GroupDMChannel(DMChannel):
type: Literal[3]
icon: Optional[str] icon: Optional[str]
owner_id: Snowflake owner_id: Snowflake
Channel = Union[GuildChannel, DMChannel, GroupDMChannel]
PrivacyLevel = Literal[1, 2]
class StageInstance(TypedDict):
id: Snowflake
guild_id: Snowflake
channel_id: Snowflake
topic: str
privacy_level: PrivacyLevel
discoverable_disabled: bool

View File

@@ -1,76 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, TypedDict, Union
from .emoji import PartialEmoji
ComponentType = Literal[1, 2, 3]
ButtonStyle = Literal[1, 2, 3, 4, 5]
class ActionRow(TypedDict):
type: Literal[1]
components: List[Component]
class _ButtonComponentOptional(TypedDict, total=False):
custom_id: str
url: str
disabled: bool
emoji: PartialEmoji
label: str
class ButtonComponent(_ButtonComponentOptional):
type: Literal[2]
style: ButtonStyle
class _SelectMenuOptional(TypedDict, total=False):
placeholder: str
min_values: int
max_values: int
disabled: bool
class _SelectOptionsOptional(TypedDict, total=False):
description: str
emoji: PartialEmoji
class SelectOption(_SelectOptionsOptional):
label: str
value: str
default: bool
class SelectMenu(_SelectMenuOptional):
type: Literal[3]
custom_id: str
options: List[SelectOption]
Component = Union[ActionRow, ButtonComponent, SelectMenu]

View File

@@ -1,46 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Optional, TypedDict
from .snowflake import Snowflake, SnowflakeList
from .user import User
class PartialEmoji(TypedDict):
id: Optional[Snowflake]
name: Optional[str]
class Emoji(PartialEmoji, total=False):
roles: SnowflakeList
user: User
require_colons: bool
managed: bool
animated: bool
available: bool
class EditEmoji(TypedDict):
name: str
roles: Optional[SnowflakeList]

View File

@@ -1,41 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict
class SessionStartLimit(TypedDict):
total: int
remaining: int
reset_after: int
max_concurrency: int
class Gateway(TypedDict):
url: str
class GatewayBot(Gateway):
shards: int
session_start_limit: SessionStartLimit

View File

@@ -1,168 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import List, Literal, Optional, TypedDict
from .snowflake import Snowflake
from .channel import GuildChannel
from .voice import GuildVoiceState
from .welcome_screen import WelcomeScreen
from .activity import PartialPresenceUpdate
from .role import Role
from .member import Member
from .emoji import Emoji
from .user import User
from .threads import Thread
class Ban(TypedDict):
reason: Optional[str]
user: User
class _UnavailableGuildOptional(TypedDict, total=False):
unavailable: bool
class UnavailableGuild(_UnavailableGuildOptional):
id: Snowflake
class _GuildOptional(TypedDict, total=False):
icon_hash: Optional[str]
owner: bool
permissions: str
widget_enabled: bool
widget_channel_id: Optional[Snowflake]
joined_at: Optional[str]
large: bool
member_count: int
voice_states: List[GuildVoiceState]
members: List[Member]
channels: List[GuildChannel]
presences: List[PartialPresenceUpdate]
threads: List[Thread]
max_presences: Optional[int]
max_members: int
premium_subscription_count: int
max_video_channel_users: int
DefaultMessageNotificationLevel = Literal[0, 1]
ExplicitContentFilterLevel = Literal[0, 1, 2]
MFALevel = Literal[0, 1]
VerificationLevel = Literal[0, 1, 2, 3, 4]
NSFWLevel = Literal[0, 1, 2, 3]
PremiumTier = Literal[0, 1, 2, 3]
GuildFeature = Literal[
'ANIMATED_ICON',
'BANNER',
'COMMERCE',
'COMMUNITY',
'DISCOVERABLE',
'FEATURABLE',
'INVITE_SPLASH',
'MEMBER_VERIFICATION_GATE_ENABLED',
'MONETIZATION_ENABLED',
'MORE_EMOJI',
'MORE_STICKERS',
'NEWS',
'PARTNERED',
'PREVIEW_ENABLED',
'PRIVATE_THREADS',
'SEVEN_DAY_THREAD_ARCHIVE',
'THREE_DAY_THREAD_ARCHIVE',
'TICKETED_EVENTS_ENABLED',
'VANITY_URL',
'VERIFIED',
'VIP_REGIONS',
'WELCOME_SCREEN_ENABLED',
]
class _BaseGuildPreview(UnavailableGuild):
name: str
icon: Optional[str]
splash: Optional[str]
discovery_splash: Optional[str]
emojis: List[Emoji]
features: List[GuildFeature]
description: Optional[str]
class _GuildPreviewUnique(TypedDict):
approximate_member_count: int
approximate_presence_count: int
class GuildPreview(_BaseGuildPreview, _GuildPreviewUnique):
...
class Guild(_BaseGuildPreview, _GuildOptional):
owner_id: Snowflake
region: str
afk_channel_id: Optional[Snowflake]
afk_timeout: int
verification_level: VerificationLevel
default_message_notifications: DefaultMessageNotificationLevel
explicit_content_filter: ExplicitContentFilterLevel
roles: List[Role]
mfa_level: MFALevel
nsfw_level: NSFWLevel
application_id: Optional[Snowflake]
system_channel_id: Optional[Snowflake]
system_channel_flags: int
rules_channel_id: Optional[Snowflake]
vanity_url_code: Optional[str]
banner: Optional[str]
premium_tier: PremiumTier
preferred_locale: str
public_updates_channel_id: Optional[Snowflake]
class InviteGuild(Guild, total=False):
welcome_screen: WelcomeScreen
class GuildWithCounts(Guild, _GuildPreviewUnique):
...
class GuildPrune(TypedDict):
pruned: Optional[int]
class ChannelPositionUpdate(TypedDict):
id: Snowflake
position: Optional[int]
lock_permissions: Optional[bool]
parent_id: Optional[Snowflake]
class _RolePositionRequired(TypedDict):
id: Snowflake
class RolePositionUpdate(_RolePositionRequired, total=False):
position: Optional[Snowflake]

View File

@@ -1,82 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Literal, Optional, TypedDict, Union
from .snowflake import Snowflake
from .user import User
class _IntegrationApplicationOptional(TypedDict, total=False):
bot: User
class IntegrationApplication(_IntegrationApplicationOptional):
id: Snowflake
name: str
icon: Optional[str]
description: str
summary: str
class IntegrationAccount(TypedDict):
id: str
name: str
IntegrationExpireBehavior = Literal[0, 1]
class PartialIntegration(TypedDict):
id: Snowflake
name: str
type: IntegrationType
account: IntegrationAccount
IntegrationType = Literal['twitch', 'youtube', 'discord']
class BaseIntegration(PartialIntegration):
enabled: bool
syncing: bool
synced_at: str
user: User
expire_behavior: IntegrationExpireBehavior
expire_grace_period: int
class StreamIntegration(BaseIntegration):
role_id: Optional[Snowflake]
enable_emoticons: bool
subscriber_count: int
revoked: bool
class BotIntegration(BaseIntegration):
application: IntegrationApplication
Integration = Union[BaseIntegration, StreamIntegration, BotIntegration]

View File

@@ -1,236 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, TYPE_CHECKING, Dict, TypedDict, Union, List, Literal
from .snowflake import Snowflake
from .components import Component, ComponentType
from .embed import Embed
from .channel import ChannelType
from .member import Member
from .role import Role
from .user import User
if TYPE_CHECKING:
from .message import AllowedMentions, Message
ApplicationCommandType = Literal[1, 2, 3]
class _ApplicationCommandOptional(TypedDict, total=False):
options: List[ApplicationCommandOption]
type: ApplicationCommandType
class ApplicationCommand(_ApplicationCommandOptional):
id: Snowflake
application_id: Snowflake
name: str
description: str
class _ApplicationCommandOptionOptional(TypedDict, total=False):
choices: List[ApplicationCommandOptionChoice]
options: List[ApplicationCommandOption]
ApplicationCommandOptionType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
class ApplicationCommandOption(_ApplicationCommandOptionOptional):
type: ApplicationCommandOptionType
name: str
description: str
required: bool
class ApplicationCommandOptionChoice(TypedDict):
name: str
value: Union[str, int]
ApplicationCommandPermissionType = Literal[1, 2]
class ApplicationCommandPermissions(TypedDict):
id: Snowflake
type: ApplicationCommandPermissionType
permission: bool
class BaseGuildApplicationCommandPermissions(TypedDict):
permissions: List[ApplicationCommandPermissions]
class PartialGuildApplicationCommandPermissions(BaseGuildApplicationCommandPermissions):
id: Snowflake
class GuildApplicationCommandPermissions(PartialGuildApplicationCommandPermissions):
application_id: Snowflake
guild_id: Snowflake
InteractionType = Literal[1, 2, 3]
class _ApplicationCommandInteractionDataOption(TypedDict):
name: str
class _ApplicationCommandInteractionDataOptionSubcommand(_ApplicationCommandInteractionDataOption):
type: Literal[1, 2]
options: List[ApplicationCommandInteractionDataOption]
class _ApplicationCommandInteractionDataOptionString(_ApplicationCommandInteractionDataOption):
type: Literal[3]
value: str
class _ApplicationCommandInteractionDataOptionInteger(_ApplicationCommandInteractionDataOption):
type: Literal[4]
value: int
class _ApplicationCommandInteractionDataOptionBoolean(_ApplicationCommandInteractionDataOption):
type: Literal[5]
value: bool
class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInteractionDataOption):
type: Literal[6, 7, 8, 9]
value: Snowflake
class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption):
type: Literal[10]
value: float
ApplicationCommandInteractionDataOption = Union[
_ApplicationCommandInteractionDataOptionString,
_ApplicationCommandInteractionDataOptionInteger,
_ApplicationCommandInteractionDataOptionSubcommand,
_ApplicationCommandInteractionDataOptionBoolean,
_ApplicationCommandInteractionDataOptionSnowflake,
_ApplicationCommandInteractionDataOptionNumber,
]
class ApplicationCommandResolvedPartialChannel(TypedDict):
id: Snowflake
type: ChannelType
permissions: str
name: str
class ApplicationCommandInteractionDataResolved(TypedDict, total=False):
users: Dict[Snowflake, User]
members: Dict[Snowflake, Member]
roles: Dict[Snowflake, Role]
channels: Dict[Snowflake, ApplicationCommandResolvedPartialChannel]
class _ApplicationCommandInteractionDataOptional(TypedDict, total=False):
options: List[ApplicationCommandInteractionDataOption]
resolved: ApplicationCommandInteractionDataResolved
target_id: Snowflake
type: ApplicationCommandType
class ApplicationCommandInteractionData(_ApplicationCommandInteractionDataOptional):
id: Snowflake
name: str
class _ComponentInteractionDataOptional(TypedDict, total=False):
values: List[str]
class ComponentInteractionData(_ComponentInteractionDataOptional):
custom_id: str
component_type: ComponentType
InteractionData = Union[ApplicationCommandInteractionData, ComponentInteractionData]
class _InteractionOptional(TypedDict, total=False):
data: InteractionData
guild_id: Snowflake
channel_id: Snowflake
member: Member
user: User
message: Message
class Interaction(_InteractionOptional):
id: Snowflake
application_id: Snowflake
type: InteractionType
token: str
version: int
class InteractionApplicationCommandCallbackData(TypedDict, total=False):
tts: bool
content: str
embeds: List[Embed]
allowed_mentions: AllowedMentions
flags: int
components: List[Component]
InteractionResponseType = Literal[1, 4, 5, 6, 7]
class _InteractionResponseOptional(TypedDict, total=False):
data: InteractionApplicationCommandCallbackData
class InteractionResponse(_InteractionResponseOptional):
type: InteractionResponseType
class MessageInteraction(TypedDict):
id: Snowflake
type: InteractionType
name: str
user: User
class _EditApplicationCommandOptional(TypedDict, total=False):
description: str
options: Optional[List[ApplicationCommandOption]]
type: ApplicationCommandType
class EditApplicationCommand(_EditApplicationCommandOptional):
name: str
default_permission: bool

View File

@@ -1,99 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Literal, Optional, TypedDict, Union
from .snowflake import Snowflake
from .guild import InviteGuild, _GuildPreviewUnique
from .channel import PartialChannel
from .user import PartialUser
from .appinfo import PartialAppInfo
InviteTargetType = Literal[1, 2]
class _InviteOptional(TypedDict, total=False):
guild: InviteGuild
inviter: PartialUser
target_user: PartialUser
target_type: InviteTargetType
target_application: PartialAppInfo
class _InviteMetadata(TypedDict, total=False):
uses: int
max_uses: int
max_age: int
temporary: bool
created_at: str
expires_at: Optional[str]
class VanityInvite(_InviteMetadata):
code: Optional[str]
class IncompleteInvite(_InviteMetadata):
code: str
channel: PartialChannel
class Invite(IncompleteInvite, _InviteOptional):
...
class InviteWithCounts(Invite, _GuildPreviewUnique):
...
class _GatewayInviteCreateOptional(TypedDict, total=False):
guild_id: Snowflake
inviter: PartialUser
target_type: InviteTargetType
target_user: PartialUser
target_application: PartialAppInfo
class GatewayInviteCreate(_GatewayInviteCreateOptional):
channel_id: Snowflake
code: str
created_at: str
max_age: int
max_uses: int
temporary: bool
uses: bool
class _GatewayInviteDeleteOptional(TypedDict, total=False):
guild_id: Snowflake
class GatewayInviteDelete(_GatewayInviteDeleteOptional):
channel_id: Snowflake
code: str
GatewayInvite = Union[GatewayInviteCreate, GatewayInviteDelete]

View File

@@ -1,63 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict
from .snowflake import SnowflakeList
from .user import User
class Nickname(TypedDict):
nick: str
class PartialMember(TypedDict):
roles: SnowflakeList
joined_at: str
deaf: str
mute: str
class Member(PartialMember, total=False):
avatar: str
user: User
nick: str
premium_since: str
pending: bool
permissions: str
class _OptionalMemberWithUser(PartialMember, total=False):
avatar: str
nick: str
premium_since: str
pending: bool
permissions: str
class MemberWithUser(_OptionalMemberWithUser):
user: User
class UserWithMember(User, total=False):
member: _OptionalMemberWithUser

View File

@@ -1,138 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
from .snowflake import Snowflake, SnowflakeList
from .member import Member, UserWithMember
from .user import User
from .emoji import PartialEmoji
from .embed import Embed
from .channel import ChannelType
from .components import Component
from .interactions import MessageInteraction
from .sticker import StickerItem
class ChannelMention(TypedDict):
id: Snowflake
guild_id: Snowflake
type: ChannelType
name: str
class Reaction(TypedDict):
count: int
me: bool
emoji: PartialEmoji
class _AttachmentOptional(TypedDict, total=False):
height: Optional[int]
width: Optional[int]
content_type: str
spoiler: bool
class Attachment(_AttachmentOptional):
id: Snowflake
filename: str
size: int
url: str
proxy_url: str
MessageActivityType = Literal[1, 2, 3, 5]
class MessageActivity(TypedDict):
type: MessageActivityType
party_id: str
class _MessageApplicationOptional(TypedDict, total=False):
cover_image: str
class MessageApplication(_MessageApplicationOptional):
id: Snowflake
description: str
icon: Optional[str]
name: str
class MessageReference(TypedDict, total=False):
message_id: Snowflake
channel_id: Snowflake
guild_id: Snowflake
fail_if_not_exists: bool
class _MessageOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
mention_channels: List[ChannelMention]
reactions: List[Reaction]
nonce: Union[int, str]
webhook_id: Snowflake
activity: MessageActivity
application: MessageApplication
application_id: Snowflake
message_reference: MessageReference
flags: int
sticker_items: List[StickerItem]
referenced_message: Optional[Message]
interaction: MessageInteraction
components: List[Component]
MessageType = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19, 20, 21]
class Message(_MessageOptional):
id: Snowflake
channel_id: Snowflake
author: User
content: str
timestamp: str
edited_timestamp: Optional[str]
tts: bool
mention_everyone: bool
mentions: List[UserWithMember]
mention_roles: SnowflakeList
attachments: List[Attachment]
embeds: List[Embed]
pinned: bool
type: MessageType
AllowedMentionType = Literal['roles', 'users', 'everyone']
class AllowedMentions(TypedDict):
parse: List[AllowedMentionType]
roles: SnowflakeList
users: SnowflakeList
replied_user: bool

View File

@@ -1,87 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import TypedDict, List
from .snowflake import Snowflake
from .member import Member
from .emoji import PartialEmoji
class _MessageEventOptional(TypedDict, total=False):
guild_id: Snowflake
class MessageDeleteEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class BulkMessageDeleteEvent(_MessageEventOptional):
ids: List[Snowflake]
channel_id: Snowflake
class _ReactionActionEventOptional(TypedDict, total=False):
guild_id: Snowflake
member: Member
class MessageUpdateEvent(_MessageEventOptional):
id: Snowflake
channel_id: Snowflake
class ReactionActionEvent(_ReactionActionEventOptional):
user_id: Snowflake
channel_id: Snowflake
message_id: Snowflake
emoji: PartialEmoji
class _ReactionClearEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEvent(_ReactionClearEventOptional):
channel_id: Snowflake
message_id: Snowflake
class _ReactionClearEmojiEventOptional(TypedDict, total=False):
guild_id: Snowflake
class ReactionClearEmojiEvent(_ReactionClearEmojiEventOptional):
channel_id: int
message_id: int
emoji: PartialEmoji
class _IntegrationDeleteEventOptional(TypedDict, total=False):
application_id: Snowflake
class IntegrationDeleteEvent(_IntegrationDeleteEventOptional):
id: Snowflake
guild_id: Snowflake

View File

@@ -1,49 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TypedDict
from .snowflake import Snowflake
class _RoleOptional(TypedDict, total=False):
tags: RoleTags
class Role(_RoleOptional):
id: Snowflake
name: str
color: int
hoist: bool
position: int
permissions: str
managed: bool
mentionable: bool
class RoleTags(TypedDict, total=False):
bot_id: Snowflake
integration_id: Snowflake
premium_subscriber: None

View File

@@ -22,7 +22,7 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from typing import List, Union from typing import List
Snowflake = Union[str, int] Snowflake = str
SnowflakeList = List[Snowflake] SnowflakeList = List[Snowflake]

View File

@@ -1,93 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, TypedDict, Union
from .snowflake import Snowflake
from .user import User
StickerFormatType = Literal[1, 2, 3]
class StickerItem(TypedDict):
id: Snowflake
name: str
format_type: StickerFormatType
class BaseSticker(TypedDict):
id: Snowflake
name: str
description: str
tags: str
format_type: StickerFormatType
class StandardSticker(BaseSticker):
type: Literal[1]
sort_value: int
pack_id: Snowflake
class _GuildStickerOptional(TypedDict, total=False):
user: User
class GuildSticker(BaseSticker, _GuildStickerOptional):
type: Literal[2]
available: bool
guild_id: Snowflake
Sticker = Union[BaseSticker, StandardSticker, GuildSticker]
class StickerPack(TypedDict):
id: Snowflake
stickers: List[StandardSticker]
name: str
sku_id: Snowflake
cover_sticker_id: Snowflake
description: str
banner_asset_id: Snowflake
class _CreateGuildStickerOptional(TypedDict, total=False):
description: str
class CreateGuildSticker(_CreateGuildStickerOptional):
name: str
tags: str
class EditGuildSticker(TypedDict, total=False):
name: str
tags: str
description: str
class ListPremiumStickerPacks(TypedDict):
sticker_packs: List[StickerPack]

View File

@@ -1,43 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import TypedDict, List, Optional
from .user import PartialUser
from .snowflake import Snowflake
class TeamMember(TypedDict):
user: PartialUser
membership_state: int
permissions: List[str]
team_id: Snowflake
class Team(TypedDict):
id: Snowflake
name: str
owner_id: Snowflake
members: List[TeamMember]
icon: Optional[str]

View File

@@ -1,49 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Optional, TypedDict
from .snowflake import Snowflake
from .user import User
from .guild import Guild
class CreateTemplate(TypedDict):
name: str
icon: Optional[bytes]
class Template(TypedDict):
code: str
name: str
description: Optional[str]
usage_count: int
creator_id: Snowflake
creator: User
created_at: str
updated_at: str
source_guild_id: Snowflake
serialized_source_guild: Guild
is_dirty: Optional[bool]

View File

@@ -1,75 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict
from .snowflake import Snowflake
ThreadType = Literal[10, 11, 12]
ThreadArchiveDuration = Literal[60, 1440, 4320, 10080]
class ThreadMember(TypedDict):
id: Snowflake
user_id: Snowflake
join_timestamp: str
flags: int
class _ThreadMetadataOptional(TypedDict, total=False):
archiver_id: Snowflake
locked: bool
invitable: bool
class ThreadMetadata(_ThreadMetadataOptional):
archived: bool
auto_archive_duration: ThreadArchiveDuration
archive_timestamp: str
class _ThreadOptional(TypedDict, total=False):
member: ThreadMember
last_message_id: Optional[Snowflake]
last_pin_timestamp: Optional[Snowflake]
class Thread(_ThreadOptional):
id: Snowflake
guild_id: Snowflake
parent_id: Snowflake
owner_id: Snowflake
name: str
type: ThreadType
member_count: int
message_count: int
rate_limit_per_user: int
thread_metadata: ThreadMetadata
class ThreadPaginationPayload(TypedDict):
threads: List[Thread]
members: List[ThreadMember]
has_more: bool

View File

@@ -23,7 +23,7 @@ DEALINGS IN THE SOFTWARE.
""" """
from .snowflake import Snowflake from .snowflake import Snowflake
from typing import Literal, Optional, TypedDict from typing import Optional, TypedDict
class PartialUser(TypedDict): class PartialUser(TypedDict):
@@ -31,18 +31,3 @@ class PartialUser(TypedDict):
username: str username: str
discriminator: str discriminator: str
avatar: Optional[str] avatar: Optional[str]
PremiumType = Literal[0, 1, 2]
class User(PartialUser, total=False):
bot: bool
system: bool
mfa_enabled: bool
local: str
verified: bool
email: Optional[str]
flags: int
premium_type: PremiumType
public_flags: int

View File

@@ -1,85 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import Optional, TypedDict, List, Literal
from .snowflake import Snowflake
from .member import MemberWithUser
SupportedModes = Literal['xsalsa20_poly1305_lite', 'xsalsa20_poly1305_suffix', 'xsalsa20_poly1305']
class _PartialVoiceStateOptional(TypedDict, total=False):
member: MemberWithUser
self_stream: bool
class _VoiceState(_PartialVoiceStateOptional):
user_id: Snowflake
session_id: str
deaf: bool
mute: bool
self_deaf: bool
self_mute: bool
self_video: bool
suppress: bool
class GuildVoiceState(_VoiceState):
channel_id: Snowflake
class VoiceState(_VoiceState, total=False):
channel_id: Optional[Snowflake]
guild_id: Snowflake
class VoiceRegion(TypedDict):
id: str
name: str
vip: bool
optimal: bool
deprecated: bool
custom: bool
class VoiceServerUpdate(TypedDict):
token: str
guild_id: Snowflake
endpoint: Optional[str]
class VoiceIdentify(TypedDict):
server_id: Snowflake
user_id: Snowflake
session_id: str
token: str
class VoiceReady(TypedDict):
ssrc: int
ip: str
port: int
modes: List[SupportedModes]
heartbeat_interval: int

View File

@@ -1,70 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Literal, Optional, TypedDict
from .snowflake import Snowflake
from .user import User
from .channel import PartialChannel
class SourceGuild(TypedDict):
id: int
name: str
icon: str
class _WebhookOptional(TypedDict, total=False):
guild_id: Snowflake
user: User
token: str
WebhookType = Literal[1, 2, 3]
class _FollowerWebhookOptional(TypedDict, total=False):
source_channel: PartialChannel
source_guild: SourceGuild
class FollowerWebhook(_FollowerWebhookOptional):
channel_id: Snowflake
webhook_id: Snowflake
class PartialWebhook(_WebhookOptional):
id: Snowflake
type: WebhookType
class _FullWebhook(TypedDict, total=False):
name: Optional[str]
avatar: Optional[str]
channel_id: Snowflake
application_id: Optional[Snowflake]
class Webhook(PartialWebhook, _FullWebhook):
...

View File

@@ -1,40 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Optional, TypedDict
from .snowflake import Snowflake
class WelcomeScreen(TypedDict):
description: str
welcome_channels: List[WelcomeScreenChannel]
class WelcomeScreenChannel(TypedDict):
channel_id: Snowflake
description: str
emoji_id: Optional[Snowflake]
emoji_name: Optional[str]

View File

@@ -1,63 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from typing import List, Optional, TypedDict
from .activity import Activity
from .snowflake import Snowflake
from .user import User
class WidgetChannel(TypedDict):
id: Snowflake
name: str
position: int
class WidgetMember(User, total=False):
nick: str
game: Activity
status: str
avatar_url: str
deaf: bool
self_deaf: bool
mute: bool
self_mute: bool
suppress: bool
class _WidgetOptional(TypedDict, total=False):
channels: List[WidgetChannel]
members: List[WidgetMember]
presence_count: int
class Widget(_WidgetOptional):
id: Snowflake
name: str
instant_invite: str
class WidgetSettings(TypedDict):
enabled: bool
channel_id: Optional[Snowflake]

View File

@@ -1,15 +0,0 @@
"""
discord.ui
~~~~~~~~~~~
Bot UI Kit helper for the Discord API
:copyright: (c) 2015-present Rapptz
:license: MIT, see LICENSE for more details.
"""
from .view import *
from .item import *
from .button import *
from .select import *

View File

@@ -1,290 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Callable, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union
import inspect
import os
from .item import Item, ItemCallbackType
from ..enums import ButtonStyle, ComponentType
from ..partial_emoji import PartialEmoji, _EmojiTag
from ..components import Button as ButtonComponent
__all__ = (
'Button',
'button',
)
if TYPE_CHECKING:
from .view import View
from ..emoji import Emoji
B = TypeVar('B', bound='Button')
V = TypeVar('V', bound='View', covariant=True)
class Button(Item[V]):
"""Represents a UI button.
.. versionadded:: 2.0
Parameters
------------
style: :class:`discord.ButtonStyle`
The style of the button.
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID.
url: Optional[:class:`str`]
The URL this button sends you to.
disabled: :class:`bool`
Whether the button is disabled or not.
label: Optional[:class:`str`]
The label of the button, if any.
emoji: Optional[Union[:class:`.PartialEmoji`, :class:`.Emoji`, :class:`str`]]
The emoji of the button, if available.
row: Optional[:class:`int`]
The relative row this button belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__: Tuple[str, ...] = (
'style',
'url',
'disabled',
'label',
'emoji',
'row',
)
def __init__(
self,
*,
style: ButtonStyle = ButtonStyle.secondary,
label: Optional[str] = None,
disabled: bool = False,
custom_id: Optional[str] = None,
url: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None,
):
super().__init__()
if custom_id is not None and url is not None:
raise TypeError('cannot mix both url and custom_id with Button')
self._provided_custom_id = custom_id is not None
if url is None and custom_id is None:
custom_id = os.urandom(16).hex()
if url is not None:
style = ButtonStyle.link
if emoji is not None:
if isinstance(emoji, str):
emoji = PartialEmoji.from_str(emoji)
elif isinstance(emoji, _EmojiTag):
emoji = emoji._to_partial()
else:
raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}')
self._underlying = ButtonComponent._raw_construct(
type=ComponentType.button,
custom_id=custom_id,
url=url,
disabled=disabled,
label=label,
style=style,
emoji=emoji,
)
self.row = row
@property
def style(self) -> ButtonStyle:
""":class:`discord.ButtonStyle`: The style of the button."""
return self._underlying.style
@style.setter
def style(self, value: ButtonStyle):
self._underlying.style = value
@property
def custom_id(self) -> Optional[str]:
"""Optional[:class:`str`]: The ID of the button that gets received during an interaction.
If this button is for a URL, it does not have a custom ID.
"""
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('custom_id must be None or str')
self._underlying.custom_id = value
@property
def url(self) -> Optional[str]:
"""Optional[:class:`str`]: The URL this button sends you to."""
return self._underlying.url
@url.setter
def url(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('url must be None or str')
self._underlying.url = value
@property
def disabled(self) -> bool:
""":class:`bool`: Whether the button is disabled or not."""
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
self._underlying.disabled = bool(value)
@property
def label(self) -> Optional[str]:
"""Optional[:class:`str`]: The label of the button, if available."""
return self._underlying.label
@label.setter
def label(self, value: Optional[str]):
self._underlying.label = str(value) if value is not None else value
@property
def emoji(self) -> Optional[PartialEmoji]:
"""Optional[:class:`.PartialEmoji`]: The emoji of the button, if available."""
return self._underlying.emoji
@emoji.setter
def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore
if value is not None:
if isinstance(value, str):
self._underlying.emoji = PartialEmoji.from_str(value)
elif isinstance(value, _EmojiTag):
self._underlying.emoji = value._to_partial()
else:
raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead')
else:
self._underlying.emoji = None
@classmethod
def from_component(cls: Type[B], button: ButtonComponent) -> B:
return cls(
style=button.style,
label=button.label,
disabled=button.disabled,
custom_id=button.custom_id,
url=button.url,
emoji=button.emoji,
row=None,
)
@property
def type(self) -> ComponentType:
return self._underlying.type
def to_component_dict(self):
return self._underlying.to_dict()
def is_dispatchable(self) -> bool:
return self.custom_id is not None
def is_persistent(self) -> bool:
if self.style is ButtonStyle.link:
return self.url is not None
return super().is_persistent()
def refresh_component(self, button: ButtonComponent) -> None:
self._underlying = button
def button(
*,
label: Optional[str] = None,
custom_id: Optional[str] = None,
disabled: bool = False,
style: ButtonStyle = ButtonStyle.secondary,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
row: Optional[int] = None,
) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a button to a component.
The function being decorated should have three parameters, ``self`` representing
the :class:`discord.ui.View`, the :class:`discord.ui.Button` being pressed and
the :class:`discord.Interaction` you receive.
.. note::
Buttons with a URL cannot be created with this function.
Consider creating a :class:`Button` manually instead.
This is because buttons with a URL do not have a callback
associated with them since Discord does not do any processing
with it.
Parameters
------------
label: Optional[:class:`str`]
The label of the button, if any.
custom_id: Optional[:class:`str`]
The ID of the button that gets received during an interaction.
It is recommended not to set this parameter to prevent conflicts.
style: :class:`.ButtonStyle`
The style of the button. Defaults to :attr:`.ButtonStyle.grey`.
disabled: :class:`bool`
Whether the button is disabled or not. Defaults to ``False``.
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
The emoji of the button. This can be in string form or a :class:`.PartialEmoji`
or a full :class:`.Emoji`.
row: Optional[:class:`int`]
The relative row this button belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('button function must be a coroutine function')
func.__discord_ui_model_type__ = Button
func.__discord_ui_model_kwargs__ = {
'style': style,
'custom_id': custom_id,
'url': None,
'disabled': disabled,
'label': label,
'emoji': emoji,
'row': row,
}
return func
return decorator

View File

@@ -1,131 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Callable, Coroutine, Dict, Generic, Optional, TYPE_CHECKING, Tuple, Type, TypeVar
from ..interactions import Interaction
__all__ = (
'Item',
)
if TYPE_CHECKING:
from ..enums import ComponentType
from .view import View
from ..components import Component
I = TypeVar('I', bound='Item')
V = TypeVar('V', bound='View', covariant=True)
ItemCallbackType = Callable[[Any, I, Interaction], Coroutine[Any, Any, Any]]
class Item(Generic[V]):
"""Represents the base UI item that all UI components inherit from.
The current UI items supported are:
- :class:`discord.ui.Button`
- :class:`discord.ui.Select`
.. versionadded:: 2.0
"""
__item_repr_attributes__: Tuple[str, ...] = ('row',)
def __init__(self):
self._view: Optional[V] = None
self._row: Optional[int] = None
self._rendered_row: Optional[int] = None
# This works mostly well but there is a gotcha with
# the interaction with from_component, since that technically provides
# a custom_id most dispatchable items would get this set to True even though
# it might not be provided by the library user. However, this edge case doesn't
# actually affect the intended purpose of this check because from_component is
# only called upon edit and we're mainly interested during initial creation time.
self._provided_custom_id: bool = False
def to_component_dict(self) -> Dict[str, Any]:
raise NotImplementedError
def refresh_component(self, component: Component) -> None:
return None
def refresh_state(self, interaction: Interaction) -> None:
return None
@classmethod
def from_component(cls: Type[I], component: Component) -> I:
return cls()
@property
def type(self) -> ComponentType:
raise NotImplementedError
def is_dispatchable(self) -> bool:
return False
def is_persistent(self) -> bool:
return self._provided_custom_id
def __repr__(self) -> str:
attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__item_repr_attributes__)
return f'<{self.__class__.__name__} {attrs}>'
@property
def row(self) -> Optional[int]:
return self._row
@row.setter
def row(self, value: Optional[int]):
if value is None:
self._row = None
elif 5 > value >= 0:
self._row = value
else:
raise ValueError('row cannot be negative or greater than or equal to 5')
@property
def width(self) -> int:
return 1
@property
def view(self) -> Optional[V]:
"""Optional[:class:`View`]: The underlying view for this item."""
return self._view
async def callback(self, interaction: Interaction):
"""|coro|
The callback associated with this UI item.
This can be overriden by subclasses.
Parameters
-----------
interaction: :class:`.Interaction`
The interaction that triggered this UI item.
"""
pass

View File

@@ -1,357 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type, Callable, Union
import inspect
import os
from .item import Item, ItemCallbackType
from ..enums import ComponentType
from ..partial_emoji import PartialEmoji
from ..emoji import Emoji
from ..interactions import Interaction
from ..utils import MISSING
from ..components import (
SelectOption,
SelectMenu,
)
__all__ = (
'Select',
'select',
)
if TYPE_CHECKING:
from .view import View
from ..types.components import SelectMenu as SelectMenuPayload
from ..types.interactions import (
ComponentInteractionData,
)
S = TypeVar('S', bound='Select')
V = TypeVar('V', bound='View', covariant=True)
class Select(Item[V]):
"""Represents a UI select menu.
This is usually represented as a drop down menu.
In order to get the selected items that the user has chosen, use :attr:`Select.values`.
.. versionadded:: 2.0
Parameters
------------
custom_id: :class:`str`
The ID of the select menu that gets received during an interaction.
If not given then one is generated for you.
placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any.
min_values: :class:`int`
The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
max_values: :class:`int`
The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not.
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
"""
__item_repr_attributes__: Tuple[str, ...] = (
'placeholder',
'min_values',
'max_values',
'options',
'disabled',
)
def __init__(
self,
*,
custom_id: str = MISSING,
placeholder: Optional[str] = None,
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> None:
super().__init__()
self._selected_values: List[str] = []
self._provided_custom_id = custom_id is not MISSING
custom_id = os.urandom(16).hex() if custom_id is MISSING else custom_id
options = [] if options is MISSING else options
self._underlying = SelectMenu._raw_construct(
custom_id=custom_id,
type=ComponentType.select,
placeholder=placeholder,
min_values=min_values,
max_values=max_values,
options=options,
disabled=disabled,
)
self.row = row
@property
def custom_id(self) -> str:
""":class:`str`: The ID of the select menu that gets received during an interaction."""
return self._underlying.custom_id
@custom_id.setter
def custom_id(self, value: str):
if not isinstance(value, str):
raise TypeError('custom_id must be None or str')
self._underlying.custom_id = value
@property
def placeholder(self) -> Optional[str]:
"""Optional[:class:`str`]: The placeholder text that is shown if nothing is selected, if any."""
return self._underlying.placeholder
@placeholder.setter
def placeholder(self, value: Optional[str]):
if value is not None and not isinstance(value, str):
raise TypeError('placeholder must be None or str')
self._underlying.placeholder = value
@property
def min_values(self) -> int:
""":class:`int`: The minimum number of items that must be chosen for this select menu."""
return self._underlying.min_values
@min_values.setter
def min_values(self, value: int):
self._underlying.min_values = int(value)
@property
def max_values(self) -> int:
""":class:`int`: The maximum number of items that must be chosen for this select menu."""
return self._underlying.max_values
@max_values.setter
def max_values(self, value: int):
self._underlying.max_values = int(value)
@property
def options(self) -> List[SelectOption]:
"""List[:class:`discord.SelectOption`]: A list of options that can be selected in this menu."""
return self._underlying.options
@options.setter
def options(self, value: List[SelectOption]):
if not isinstance(value, list):
raise TypeError('options must be a list of SelectOption')
if not all(isinstance(obj, SelectOption) for obj in value):
raise TypeError('all list items must subclass SelectOption')
self._underlying.options = value
def add_option(
self,
*,
label: str,
value: str = MISSING,
description: Optional[str] = None,
emoji: Optional[Union[str, Emoji, PartialEmoji]] = None,
default: bool = False,
):
"""Adds an option to the select menu.
To append a pre-existing :class:`discord.SelectOption` use the
:meth:`append_option` method instead.
Parameters
-----------
label: :class:`str`
The label of the option. This is displayed to users.
Can only be up to 100 characters.
value: :class:`str`
The value of the option. This is not displayed to users.
If not given, defaults to the label. Can only be up to 100 characters.
description: Optional[:class:`str`]
An additional description of the option, if any.
Can only be up to 100 characters.
emoji: Optional[Union[:class:`str`, :class:`.Emoji`, :class:`.PartialEmoji`]]
The emoji of the option, if available. This can either be a string representing
the custom or unicode emoji or an instance of :class:`.PartialEmoji` or :class:`.Emoji`.
default: :class:`bool`
Whether this option is selected by default.
Raises
-------
ValueError
The number of options exceeds 25.
"""
option = SelectOption(
label=label,
value=value,
description=description,
emoji=emoji,
default=default,
)
self.append_option(option)
def append_option(self, option: SelectOption):
"""Appends an option to the select menu.
Parameters
-----------
option: :class:`discord.SelectOption`
The option to append to the select menu.
Raises
-------
ValueError
The number of options exceeds 25.
"""
if len(self._underlying.options) > 25:
raise ValueError('maximum number of options already provided')
self._underlying.options.append(option)
@property
def disabled(self) -> bool:
""":class:`bool`: Whether the select is disabled or not."""
return self._underlying.disabled
@disabled.setter
def disabled(self, value: bool):
self._underlying.disabled = bool(value)
@property
def values(self) -> List[str]:
"""List[:class:`str`]: A list of values that have been selected by the user."""
return self._selected_values
@property
def width(self) -> int:
return 5
def to_component_dict(self) -> SelectMenuPayload:
return self._underlying.to_dict()
def refresh_component(self, component: SelectMenu) -> None:
self._underlying = component
def refresh_state(self, interaction: Interaction) -> None:
data: ComponentInteractionData = interaction.data # type: ignore
self._selected_values = data.get('values', [])
@classmethod
def from_component(cls: Type[S], component: SelectMenu) -> S:
return cls(
custom_id=component.custom_id,
placeholder=component.placeholder,
min_values=component.min_values,
max_values=component.max_values,
options=component.options,
disabled=component.disabled,
row=None,
)
@property
def type(self) -> ComponentType:
return self._underlying.type
def is_dispatchable(self) -> bool:
return True
def select(
*,
placeholder: Optional[str] = None,
custom_id: str = MISSING,
min_values: int = 1,
max_values: int = 1,
options: List[SelectOption] = MISSING,
disabled: bool = False,
row: Optional[int] = None,
) -> Callable[[ItemCallbackType], ItemCallbackType]:
"""A decorator that attaches a select menu to a component.
The function being decorated should have three parameters, ``self`` representing
the :class:`discord.ui.View`, the :class:`discord.ui.Select` being pressed and
the :class:`discord.Interaction` you receive.
In order to get the selected items that the user has chosen within the callback
use :attr:`Select.values`.
Parameters
------------
placeholder: Optional[:class:`str`]
The placeholder text that is shown if nothing is selected, if any.
custom_id: :class:`str`
The ID of the select menu that gets received during an interaction.
It is recommended not to set this parameter to prevent conflicts.
row: Optional[:class:`int`]
The relative row this select menu belongs to. A Discord component can only have 5
rows. By default, items are arranged automatically into those 5 rows. If you'd
like to control the relative positioning of the row then passing an index is advised.
For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic
ordering. The row number must be between 0 and 4 (i.e. zero indexed).
min_values: :class:`int`
The minimum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
max_values: :class:`int`
The maximum number of items that must be chosen for this select menu.
Defaults to 1 and must be between 1 and 25.
options: List[:class:`discord.SelectOption`]
A list of options that can be selected in this menu.
disabled: :class:`bool`
Whether the select is disabled or not. Defaults to ``False``.
"""
def decorator(func: ItemCallbackType) -> ItemCallbackType:
if not inspect.iscoroutinefunction(func):
raise TypeError('select function must be a coroutine function')
func.__discord_ui_model_type__ = Select
func.__discord_ui_model_kwargs__ = {
'placeholder': placeholder,
'custom_id': custom_id,
'row': row,
'min_values': min_values,
'max_values': max_values,
'options': options,
'disabled': disabled,
}
return func
return decorator

View File

@@ -1,529 +0,0 @@
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from __future__ import annotations
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Tuple
from functools import partial
from itertools import groupby
import traceback
import asyncio
import sys
import time
import os
from .item import Item, ItemCallbackType
from ..components import (
Component,
ActionRow as ActionRowComponent,
_component_factory,
Button as ButtonComponent,
SelectMenu as SelectComponent,
)
__all__ = (
'View',
)
if TYPE_CHECKING:
from ..interactions import Interaction
from ..message import Message
from ..types.components import Component as ComponentPayload
from ..state import ConnectionState
def _walk_all_components(components: List[Component]) -> Iterator[Component]:
for item in components:
if isinstance(item, ActionRowComponent):
yield from item.children
else:
yield item
def _component_to_item(component: Component) -> Item:
if isinstance(component, ButtonComponent):
from .button import Button
return Button.from_component(component)
if isinstance(component, SelectComponent):
from .select import Select
return Select.from_component(component)
return Item.from_component(component)
class _ViewWeights:
__slots__ = (
'weights',
)
def __init__(self, children: List[Item]):
self.weights: List[int] = [0, 0, 0, 0, 0]
key = lambda i: sys.maxsize if i.row is None else i.row
children = sorted(children, key=key)
for row, group in groupby(children, key=key):
for item in group:
self.add_item(item)
def find_open_space(self, item: Item) -> int:
for index, weight in enumerate(self.weights):
if weight + item.width <= 5:
return index
raise ValueError('could not find open space for item')
def add_item(self, item: Item) -> None:
if item.row is not None:
total = self.weights[item.row] + item.width
if total > 5:
raise ValueError(f'item would not fit at row {item.row} ({total} > 5 width)')
self.weights[item.row] = total
item._rendered_row = item.row
else:
index = self.find_open_space(item)
self.weights[index] += item.width
item._rendered_row = index
def remove_item(self, item: Item) -> None:
if item._rendered_row is not None:
self.weights[item._rendered_row] -= item.width
item._rendered_row = None
def clear(self) -> None:
self.weights = [0, 0, 0, 0, 0]
class View:
"""Represents a UI view.
This object must be inherited to create a UI within Discord.
.. versionadded:: 2.0
Parameters
-----------
timeout: Optional[:class:`float`]
Timeout in seconds from last interaction with the UI before no longer accepting input.
If ``None`` then there is no timeout.
Attributes
------------
timeout: Optional[:class:`float`]
Timeout from last interaction with the UI before no longer accepting input.
If ``None`` then there is no timeout.
children: List[:class:`Item`]
The list of children attached to this view.
"""
__discord_ui_view__: ClassVar[bool] = True
__view_children_items__: ClassVar[List[ItemCallbackType]] = []
def __init_subclass__(cls) -> None:
children: List[ItemCallbackType] = []
for base in reversed(cls.__mro__):
for member in base.__dict__.values():
if hasattr(member, '__discord_ui_model_type__'):
children.append(member)
if len(children) > 25:
raise TypeError('View cannot have more than 25 children')
cls.__view_children_items__ = children
def __init__(self, *, timeout: Optional[float] = 180.0):
self.timeout = timeout
self.children: List[Item] = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__)
item.callback = partial(func, self, item)
item._view = self
setattr(self, func.__name__, item)
self.children.append(item)
self.__weights = _ViewWeights(self.children)
loop = asyncio.get_running_loop()
self.id: str = os.urandom(16).hex()
self.__cancel_callback: Optional[Callable[[View], None]] = None
self.__timeout_expiry: Optional[float] = None
self.__timeout_task: Optional[asyncio.Task[None]] = None
self.__stopped: asyncio.Future[bool] = loop.create_future()
def __repr__(self) -> str:
return f'<{self.__class__.__name__} timeout={self.timeout} children={len(self.children)}>'
async def __timeout_task_impl(self) -> None:
while True:
# Guard just in case someone changes the value of the timeout at runtime
if self.timeout is None:
return
if self.__timeout_expiry is None:
return self._dispatch_timeout()
# Check if we've elapsed our currently set timeout
now = time.monotonic()
if now >= self.__timeout_expiry:
return self._dispatch_timeout()
# Wait N seconds to see if timeout data has been refreshed
await asyncio.sleep(self.__timeout_expiry - now)
def to_components(self) -> List[Dict[str, Any]]:
def key(item: Item) -> int:
return item._rendered_row or 0
children = sorted(self.children, key=key)
components: List[Dict[str, Any]] = []
for _, group in groupby(children, key=key):
children = [item.to_component_dict() for item in group]
if not children:
continue
components.append(
{
'type': 1,
'components': children,
}
)
return components
@classmethod
def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View:
"""Converts a message's components into a :class:`View`.
The :attr:`.Message.components` of a message are read-only
and separate types from those in the ``discord.ui`` namespace.
In order to modify and edit message components they must be
converted into a :class:`View` first.
Parameters
-----------
message: :class:`discord.Message`
The message with components to convert into a view.
timeout: Optional[:class:`float`]
The timeout of the converted view.
Returns
--------
:class:`View`
The converted view. This always returns a :class:`View` and not
one of its subclasses.
"""
view = View(timeout=timeout)
for component in _walk_all_components(message.components):
view.add_item(_component_to_item(component))
return view
@property
def _expires_at(self) -> Optional[float]:
if self.timeout:
return time.monotonic() + self.timeout
return None
def add_item(self, item: Item) -> None:
"""Adds an item to the view.
Parameters
-----------
item: :class:`Item`
The item to add to the view.
Raises
--------
TypeError
An :class:`Item` was not passed.
ValueError
Maximum number of children has been exceeded (25)
or the row the item is trying to be added to is full.
"""
if len(self.children) > 25:
raise ValueError('maximum number of children exceeded')
if not isinstance(item, Item):
raise TypeError(f'expected Item not {item.__class__!r}')
self.__weights.add_item(item)
item._view = self
self.children.append(item)
def remove_item(self, item: Item) -> None:
"""Removes an item from the view.
Parameters
-----------
item: :class:`Item`
The item to remove from the view.
"""
try:
self.children.remove(item)
except ValueError:
pass
else:
self.__weights.remove_item(item)
def clear_items(self) -> None:
"""Removes all items from the view."""
self.children.clear()
self.__weights.clear()
async def interaction_check(self, interaction: Interaction) -> bool:
"""|coro|
A callback that is called when an interaction happens within the view
that checks whether the view should process item callbacks for the interaction.
This is useful to override if, for example, you want to ensure that the
interaction author is a given user.
The default implementation of this returns ``True``.
.. note::
If an exception occurs within the body then the check
is considered a failure and :meth:`on_error` is called.
Parameters
-----------
interaction: :class:`~discord.Interaction`
The interaction that occurred.
Returns
---------
:class:`bool`
Whether the view children's callbacks should be called.
"""
return True
async def on_timeout(self) -> None:
"""|coro|
A callback that is called when a view's timeout elapses without being explicitly stopped.
"""
pass
async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
"""|coro|
A callback that is called when an item's callback or :meth:`interaction_check`
fails with an error.
The default implementation prints the traceback to stderr.
Parameters
-----------
error: :class:`Exception`
The exception that was raised.
item: :class:`Item`
The item that failed the dispatch.
interaction: :class:`~discord.Interaction`
The interaction that led to the failure.
"""
print(f'Ignoring exception in view {self} for item {item}:', file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
async def _scheduled_task(self, item: Item, interaction: Interaction):
try:
if self.timeout:
self.__timeout_expiry = time.monotonic() + self.timeout
allow = await self.interaction_check(interaction)
if not allow:
return
await item.callback(interaction)
if not interaction.response._responded:
await interaction.response.defer()
except Exception as e:
return await self.on_error(e, item, interaction)
def _start_listening_from_store(self, store: ViewStore) -> None:
self.__cancel_callback = partial(store.remove_view)
if self.timeout:
loop = asyncio.get_running_loop()
if self.__timeout_task is not None:
self.__timeout_task.cancel()
self.__timeout_expiry = time.monotonic() + self.timeout
self.__timeout_task = loop.create_task(self.__timeout_task_impl())
def _dispatch_timeout(self):
if self.__stopped.done():
return
self.__stopped.set_result(True)
asyncio.create_task(self.on_timeout(), name=f'discord-ui-view-timeout-{self.id}')
def _dispatch_item(self, item: Item, interaction: Interaction):
if self.__stopped.done():
return
asyncio.create_task(self._scheduled_task(item, interaction), name=f'discord-ui-view-dispatch-{self.id}')
def refresh(self, components: List[Component]):
# This is pretty hacky at the moment
# fmt: off
old_state: Dict[Tuple[int, str], Item] = {
(item.type.value, item.custom_id): item # type: ignore
for item in self.children
if item.is_dispatchable()
}
# fmt: on
children: List[Item] = []
for component in _walk_all_components(components):
try:
older = old_state[(component.type.value, component.custom_id)] # type: ignore
except (KeyError, AttributeError):
children.append(_component_to_item(component))
else:
older.refresh_component(component)
children.append(older)
self.children = children
def stop(self) -> None:
"""Stops listening to interaction events from this view.
This operation cannot be undone.
"""
if not self.__stopped.done():
self.__stopped.set_result(False)
self.__timeout_expiry = None
if self.__timeout_task is not None:
self.__timeout_task.cancel()
self.__timeout_task = None
if self.__cancel_callback:
self.__cancel_callback(self)
self.__cancel_callback = None
def is_finished(self) -> bool:
""":class:`bool`: Whether the view has finished interacting."""
return self.__stopped.done()
def is_dispatching(self) -> bool:
""":class:`bool`: Whether the view has been added for dispatching purposes."""
return self.__cancel_callback is not None
def is_persistent(self) -> bool:
""":class:`bool`: Whether the view is set up as persistent.
A persistent view has all their components with a set ``custom_id`` and
a :attr:`timeout` set to ``None``.
"""
return self.timeout is None and all(item.is_persistent() for item in self.children)
async def wait(self) -> bool:
"""Waits until the view has finished interacting.
A view is considered finished when :meth:`stop` is called
or it times out.
Returns
--------
:class:`bool`
If ``True``, then the view timed out. If ``False`` then
the view finished normally.
"""
return await self.__stopped
class ViewStore:
def __init__(self, state: ConnectionState):
# (component_type, message_id, custom_id): (View, Item)
self._views: Dict[Tuple[int, Optional[int], str], Tuple[View, Item]] = {}
# message_id: View
self._synced_message_views: Dict[int, View] = {}
self._state: ConnectionState = state
@property
def persistent_views(self) -> Sequence[View]:
# fmt: off
views = {
view.id: view
for (_, (view, _)) in self._views.items()
if view.is_persistent()
}
# fmt: on
return list(views.values())
def __verify_integrity(self):
to_remove: List[Tuple[int, Optional[int], str]] = []
for (k, (view, _)) in self._views.items():
if view.is_finished():
to_remove.append(k)
for k in to_remove:
del self._views[k]
def add_view(self, view: View, message_id: Optional[int] = None):
self.__verify_integrity()
view._start_listening_from_store(self)
for item in view.children:
if item.is_dispatchable():
self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore
if message_id is not None:
self._synced_message_views[message_id] = view
def remove_view(self, view: View):
for item in view.children:
if item.is_dispatchable():
self._views.pop((item.type.value, item.custom_id), None) # type: ignore
for key, value in self._synced_message_views.items():
if value.id == view.id:
del self._synced_message_views[key]
break
def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
self.__verify_integrity()
message_id: Optional[int] = interaction.message and interaction.message.id
key = (component_type, message_id, custom_id)
# Fallback to None message_id searches in case a persistent view
# was added without an associated message_id
value = self._views.get(key) or self._views.get((component_type, None, custom_id))
if value is None:
return
view, item = value
item.refresh_state(interaction)
view._dispatch_item(item, interaction)
def is_message_tracked(self, message_id: int):
return message_id in self._synced_message_views
def remove_message_tracking(self, message_id: int) -> Optional[View]:
return self._synced_message_views.pop(message_id, None)
def update_from_message(self, message_id: int, components: List[ComponentPayload]):
# pre-req: is_message_tracked == true
view = self._synced_message_views[message_id]
view.refresh([_component_factory(d) for d in components])

View File

@@ -22,202 +22,140 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type, TypeVar, TYPE_CHECKING
import discord.abc import discord.abc
from .asset import Asset
from .colour import Colour
from .enums import DefaultAvatar
from .flags import PublicUserFlags from .flags import PublicUserFlags
from .utils import snowflake_time, _bytes_to_base64_data, MISSING from .utils import snowflake_time, _bytes_to_base64_data
from .enums import DefaultAvatar, try_enum
if TYPE_CHECKING: from .colour import Colour
from datetime import datetime from .asset import Asset
from .channel import DMChannel
from .guild import Guild
from .message import Message
from .state import ConnectionState
from .types.channel import DMChannel as DMChannelPayload
from .types.user import User as UserPayload
__all__ = ( __all__ = (
'User', 'User',
'ClientUser', 'ClientUser',
) )
BU = TypeVar('BU', bound='BaseUser') _BaseUser = discord.abc.User
class BaseUser(_BaseUser):
__slots__ = ('name', 'id', 'discriminator', 'avatar', 'bot', 'system', '_public_flags', '_state')
class _UserTag: def __init__(self, *, state, data):
__slots__ = ()
id: int
class BaseUser(_UserTag):
__slots__ = (
'name',
'id',
'discriminator',
'_avatar',
'_banner',
'_accent_colour',
'bot',
'system',
'_public_flags',
'_state',
)
if TYPE_CHECKING:
name: str
id: int
discriminator: str
bot: bool
system: bool
_state: ConnectionState
_avatar: Optional[str]
_banner: Optional[str]
_accent_colour: Optional[str]
_public_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
self._state = state self._state = state
self._update(data) self._update(data)
def __repr__(self) -> str: def __str__(self):
return ( return '{0.name}#{0.discriminator}'.format(self)
f"<BaseUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}"
f" bot={self.bot} system={self.system}>"
)
def __str__(self) -> str: def __int__(self):
return f'{self.name}#{self.discriminator}' return self.id
def __eq__(self, other: Any) -> bool: def __eq__(self, other):
return isinstance(other, _UserTag) and other.id == self.id return isinstance(other, _BaseUser) and other.id == self.id
def __ne__(self, other: Any) -> bool: def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
def __hash__(self) -> int: def __hash__(self):
return self.id >> 22 return self.id >> 22
def _update(self, data: UserPayload) -> None: def _update(self, data):
self.name = data['username'] self.name = data['username']
self.id = int(data['id']) self.id = int(data['id'])
self.discriminator = data['discriminator'] self.discriminator = data['discriminator']
self._avatar = data['avatar'] self.avatar = data['avatar']
self._banner = data.get('banner', None)
self._accent_colour = data.get('accent_color', None)
self._public_flags = data.get('public_flags', 0) self._public_flags = data.get('public_flags', 0)
self.bot = data.get('bot', False) self.bot = data.get('bot', False)
self.system = data.get('system', False) self.system = data.get('system', False)
@classmethod @classmethod
def _copy(cls: Type[BU], user: BU) -> BU: def _copy(cls, user):
self = cls.__new__(cls) # bypass __init__ self = cls.__new__(cls) # bypass __init__
self.name = user.name self.name = user.name
self.id = user.id self.id = user.id
self.discriminator = user.discriminator self.discriminator = user.discriminator
self._avatar = user._avatar self.avatar = user.avatar
self._banner = user._banner
self._accent_colour = user._accent_colour
self.bot = user.bot self.bot = user.bot
self._state = user._state self._state = user._state
self._public_flags = user._public_flags self._public_flags = user._public_flags
return self return self
def _to_minimal_user_json(self) -> Dict[str, Any]: def _to_minimal_user_json(self):
return { return {
'username': self.name, 'username': self.name,
'id': self.id, 'id': self.id,
'avatar': self._avatar, 'avatar': self.avatar,
'discriminator': self.discriminator, 'discriminator': self.discriminator,
'bot': self.bot, 'bot': self.bot,
} }
@property @property
def public_flags(self) -> PublicUserFlags: def public_flags(self):
""":class:`PublicUserFlags`: The publicly available flags the user has.""" """:class:`PublicUserFlags`: The publicly available flags the user has."""
return PublicUserFlags._from_value(self._public_flags) return PublicUserFlags._from_value(self._public_flags)
@property @property
def avatar(self) -> Optional[Asset]: def avatar_url(self):
"""Optional[:class:`Asset`]: Returns an :class:`Asset` for the avatar the user has. """:class:`str`: Returns an direct url for the avatar the user has.
If the user does not have a traditional avatar, ``None`` is returned. If the user does not have a traditional avatar, an asset for
If you want the avatar that a user has displayed, consider :attr:`display_avatar`. the default avatar is returned instead.
""" """
if self._avatar is not None: return str(self.avatar_url_as(static_format="png", size=1024))
return Asset._from_avatar(self._state, self.id, self._avatar)
return None
@property def is_avatar_animated(self):
def default_avatar(self) -> Asset: """:class:`bool`: Indicates if the user has an animated avatar."""
""":class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator.""" return bool(self.avatar and self.avatar.startswith('a_'))
return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar))
@property def avatar_url_as(self, *, format=None, static_format='webp', size=1024):
def display_avatar(self) -> Asset: """Returns an :class:`Asset` for the avatar the user has.
""":class:`Asset`: Returns the user's display avatar.
For regular users this is just their default avatar or uploaded avatar. If the user does not have a traditional avatar, an asset for
the default avatar is returned instead.
.. versionadded:: 2.0 The format must be one of 'webp', 'jpeg', 'jpg', 'png' or 'gif', and
'gif' is only valid for animated avatars. The size must be a power of 2
between 16 and 4096.
Parameters
-----------
format: Optional[:class:`str`]
The format to attempt to convert the avatar to.
If the format is ``None``, then it is automatically
detected into either 'gif' or static_format depending on the
avatar being animated or not.
static_format: Optional[:class:`str`]
Format to attempt to convert only non-animated avatars to.
Defaults to 'webp'
size: :class:`int`
The size of the image to display.
Raises
------
InvalidArgument
Bad image format passed to ``format`` or ``static_format``, or
invalid ``size``.
Returns
--------
:class:`Asset`
The resulting CDN asset.
""" """
return self.avatar or self.default_avatar return Asset._from_avatar(self._state, self, format=format, static_format=static_format, size=size)
@property @property
def banner(self) -> Optional[Asset]: def default_avatar(self):
"""Optional[:class:`Asset`]: Returns the user's banner asset, if available. """:class:`DefaultAvatar`: Returns the default avatar for a given user. This is calculated by the user's discriminator."""
return try_enum(DefaultAvatar, int(self.discriminator) % len(DefaultAvatar))
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._banner is None:
return None
return Asset._from_user_banner(self._state, self.id, self._banner)
@property @property
def accent_colour(self) -> Optional[Colour]: def default_avatar_url(self):
"""Optional[:class:`Colour`]: Returns the user's accent colour, if applicable. """:class:`Asset`: Returns a URL for a user's default avatar."""
return Asset(self._state, f'/embed/avatars/{self.default_avatar.value}.png')
There is an alias for this named :attr:`accent_color`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
if self._accent_colour is None:
return None
return Colour(self._accent_colour)
@property @property
def accent_color(self) -> Optional[Colour]: def colour(self):
"""Optional[:class:`Colour`]: Returns the user's accent color, if applicable.
There is an alias for this named :attr:`accent_colour`.
.. versionadded:: 2.0
.. note::
This information is only available via :meth:`Client.fetch_user`.
"""
return self.accent_colour
@property
def colour(self) -> Colour:
""":class:`Colour`: A property that returns a colour denoting the rendered colour """:class:`Colour`: A property that returns a colour denoting the rendered colour
for the user. This always returns :meth:`Colour.default`. for the user. This always returns :meth:`Colour.default`.
@@ -226,7 +164,7 @@ class BaseUser(_UserTag):
return Colour.default() return Colour.default()
@property @property
def color(self) -> Colour: def color(self):
""":class:`Colour`: A property that returns a color denoting the rendered color """:class:`Colour`: A property that returns a color denoting the rendered color
for the user. This always returns :meth:`Colour.default`. for the user. This always returns :meth:`Colour.default`.
@@ -235,12 +173,28 @@ class BaseUser(_UserTag):
return self.colour return self.colour
@property @property
def mention(self) -> str: def mention(self):
""":class:`str`: Returns a string that allows you to mention the given user.""" """:class:`str`: Returns a string that allows you to mention the given user."""
return f'<@{self.id}>' return f'<@{self.id}>'
def permissions_in(self, channel):
"""An alias for :meth:`abc.GuildChannel.permissions_for`.
Basically equivalent to:
.. code-block:: python3
channel.permissions_for(self)
Parameters
-----------
channel: :class:`abc.GuildChannel`
The channel to check your permissions for.
"""
return channel.permissions_for(self)
@property @property
def created_at(self) -> datetime: def created_at(self):
""":class:`datetime.datetime`: Returns the user's creation time in UTC. """:class:`datetime.datetime`: Returns the user's creation time in UTC.
This is when the user's Discord account was created. This is when the user's Discord account was created.
@@ -248,7 +202,7 @@ class BaseUser(_UserTag):
return snowflake_time(self.id) return snowflake_time(self.id)
@property @property
def display_name(self) -> str: def display_name(self):
""":class:`str`: Returns the user's display name. """:class:`str`: Returns the user's display name.
For regular users this is just their username, but For regular users this is just their username, but
@@ -257,7 +211,7 @@ class BaseUser(_UserTag):
""" """
return self.name return self.name
def mentioned_in(self, message: Message) -> bool: def mentioned_in(self, message):
"""Checks if the user is mentioned in the specified message. """Checks if the user is mentioned in the specified message.
Parameters Parameters
@@ -276,7 +230,6 @@ class BaseUser(_UserTag):
return any(user.id == self.id for user in message.mentions) return any(user.id == self.id for user in message.mentions)
class ClientUser(BaseUser): class ClientUser(BaseUser):
"""Represents your Discord user. """Represents your Discord user.
@@ -306,6 +259,8 @@ class ClientUser(BaseUser):
The user's unique ID. The user's unique ID.
discriminator: :class:`str` discriminator: :class:`str`
The user's discriminator. This is given when the username has conflicts. The user's discriminator. This is given when the username has conflicts.
avatar: Optional[:class:`str`]
The avatar hash the user has. Could be ``None``.
bot: :class:`bool` bot: :class:`bool`
Specifies if the user is a bot account. Specifies if the user is a bot account.
system: :class:`bool` system: :class:`bool`
@@ -314,31 +269,32 @@ class ClientUser(BaseUser):
.. versionadded:: 1.3 .. versionadded:: 1.3
verified: :class:`bool` verified: :class:`bool`
<<<<<<< HEAD
Specifies if the user's email is verified. Specifies if the user's email is verified.
email: Optional[:class:`str`]
The email the user used when registering.
.. deprecated:: 1.7
=======
Specifies if the user is a verified account.
>>>>>>> 523e35e4f3c3c49d4e471359f9fb559242bbecc8
locale: Optional[:class:`str`] locale: Optional[:class:`str`]
The IETF language tag used to identify the language the user is using. The IETF language tag used to identify the language the user is using.
mfa_enabled: :class:`bool` mfa_enabled: :class:`bool`
Specifies if the user has MFA turned on and working. Specifies if the user has MFA turned on and working.
""" """
__slots__ = BaseUser.__slots__ + \
('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__')
__slots__ = ('locale', '_flags', 'verified', 'mfa_enabled', '__weakref__') def __init__(self, *, state, data):
if TYPE_CHECKING:
verified: bool
locale: Optional[str]
mfa_enabled: bool
_flags: int
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None:
super().__init__(state=state, data=data) super().__init__(state=state, data=data)
def __repr__(self) -> str: def __repr__(self):
return ( return '<ClientUser id={0.id} name={0.name!r} discriminator={0.discriminator!r}' \
f'<ClientUser id={self.id} name={self.name!r} discriminator={self.discriminator!r}' ' bot={0.bot} verified={0.verified} mfa_enabled={0.mfa_enabled}>'.format(self)
f' bot={self.bot} verified={self.verified} mfa_enabled={self.mfa_enabled}>'
)
def _update(self, data: UserPayload) -> None: def _update(self, data):
super()._update(data) super()._update(data)
# There's actually an Optional[str] phone field as well but I won't use it # There's actually an Optional[str] phone field as well but I won't use it
self.verified = data.get('verified', False) self.verified = data.get('verified', False)
@@ -346,7 +302,8 @@ class ClientUser(BaseUser):
self._flags = data.get('flags', 0) self._flags = data.get('flags', 0)
self.mfa_enabled = data.get('mfa_enabled', False) self.mfa_enabled = data.get('mfa_enabled', False)
async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser:
async def edit(self, *, username=None, avatar=None):
"""|coro| """|coro|
Edits the current profile of the client. Edits the current profile of the client.
@@ -360,9 +317,6 @@ class ClientUser(BaseUser):
The only image formats supported for uploading is JPEG and PNG. The only image formats supported for uploading is JPEG and PNG.
.. versionchanged:: 2.0
The edit is no longer in-place, instead the newly edited client user is returned.
Parameters Parameters
----------- -----------
username: :class:`str` username: :class:`str`
@@ -377,22 +331,13 @@ class ClientUser(BaseUser):
Editing your profile failed. Editing your profile failed.
InvalidArgument InvalidArgument
Wrong image format passed for ``avatar``. Wrong image format passed for ``avatar``.
Returns
---------
:class:`ClientUser`
The newly edited client user.
""" """
payload: Dict[str, Any] = {}
if username is not MISSING:
payload['username'] = username
if avatar is not MISSING: if avatar is not None:
payload['avatar'] = _bytes_to_base64_data(avatar) avatar = _bytes_to_base64_data(avatar)
data: UserPayload = await self._state.http.edit_profile(payload)
return ClientUser(state=self._state, data=data)
data = await self._state.http.edit_profile(username=username, avatar=avatar)
self._update(data)
class User(BaseUser, discord.abc.Messageable): class User(BaseUser, discord.abc.Messageable):
"""Represents a Discord user. """Represents a Discord user.
@@ -423,40 +368,25 @@ class User(BaseUser, discord.abc.Messageable):
The user's unique ID. The user's unique ID.
discriminator: :class:`str` discriminator: :class:`str`
The user's discriminator. This is given when the username has conflicts. The user's discriminator. This is given when the username has conflicts.
avatar: Optional[:class:`str`]
The avatar hash the user has. Could be None.
bot: :class:`bool` bot: :class:`bool`
Specifies if the user is a bot account. Specifies if the user is a bot account.
system: :class:`bool` system: :class:`bool`
Specifies if the user is a system user (i.e. represents Discord officially). Specifies if the user is a system user (i.e. represents Discord officially).
""" """
__slots__ = ('_stored',) __slots__ = BaseUser.__slots__ + ('__weakref__',)
def __init__(self, *, state: ConnectionState, data: UserPayload) -> None: def __repr__(self):
super().__init__(state=state, data=data) return '<User id={0.id} name={0.name!r} discriminator={0.discriminator!r} bot={0.bot}>'.format(self)
self._stored: bool = False
def __repr__(self) -> str: async def _get_channel(self):
return f'<User id={self.id} name={self.name!r} discriminator={self.discriminator!r} bot={self.bot}>'
def __del__(self) -> None:
try:
if self._stored:
self._state.deref_user(self.id)
except Exception:
pass
@classmethod
def _copy(cls, user: User):
self = super()._copy(user)
self._stored = False
return self
async def _get_channel(self) -> DMChannel:
ch = await self.create_dm() ch = await self.create_dm()
return ch return ch
@property @property
def dm_channel(self) -> Optional[DMChannel]: def dm_channel(self):
"""Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists. """Optional[:class:`DMChannel`]: Returns the channel associated with this user if it exists.
If this returns ``None``, you can create a DM channel by calling the If this returns ``None``, you can create a DM channel by calling the
@@ -465,7 +395,7 @@ class User(BaseUser, discord.abc.Messageable):
return self._state._get_private_channel_by_user(self.id) return self._state._get_private_channel_by_user(self.id)
@property @property
def mutual_guilds(self) -> List[Guild]: def mutual_guilds(self):
"""List[:class:`Guild`]: The guilds that the user shares with the client. """List[:class:`Guild`]: The guilds that the user shares with the client.
.. note:: .. note::
@@ -476,7 +406,7 @@ class User(BaseUser, discord.abc.Messageable):
""" """
return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)] return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)]
async def create_dm(self) -> DMChannel: async def create_dm(self):
"""|coro| """|coro|
Creates a :class:`DMChannel` with this user. Creates a :class:`DMChannel` with this user.
@@ -494,5 +424,5 @@ class User(BaseUser, discord.abc.Messageable):
return found return found
state = self._state state = self._state
data: DMChannelPayload = await state.http.start_private_message(self.id) data = await state.http.start_private_message(self.id)
return state.add_dm_channel(data) return state.add_dm_channel(data)

View File

@@ -21,33 +21,11 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
import array import array
import asyncio import asyncio
import collections.abc import collections.abc
from typing import ( from typing import Optional, overload
Any,
AsyncIterator,
Callable,
Dict,
ForwardRef,
Generic,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
TYPE_CHECKING,
)
import unicodedata import unicodedata
from base64 import b64encode from base64 import b64encode
from bisect import bisect_left from bisect import bisect_left
@@ -57,22 +35,12 @@ from inspect import isawaitable as _isawaitable, signature as _signature
from operator import attrgetter from operator import attrgetter
import json import json
import re import re
import sys
import types
import warnings import warnings
from .errors import InvalidArgument from .errors import InvalidArgument
try:
import orjson
except ModuleNotFoundError:
HAS_ORJSON = False
else:
HAS_ORJSON = True
__all__ = ( __all__ = (
'oauth_url', 'oauth_uri',
'snowflake_time', 'snowflake_time',
'time_snowflake', 'time_snowflake',
'find', 'find',
@@ -82,28 +50,10 @@ __all__ = (
'remove_markdown', 'remove_markdown',
'escape_markdown', 'escape_markdown',
'escape_mentions', 'escape_mentions',
'as_chunks',
'format_dt',
) )
DISCORD_EPOCH = 1420070400000 DISCORD_EPOCH = 1420070400000
class cached_property:
class _MissingSentinel:
def __eq__(self, other):
return False
def __bool__(self):
return False
def __repr__(self):
return '...'
MISSING: Any = _MissingSentinel()
class _cached_property:
def __init__(self, function): def __init__(self, function):
self.function = function self.function = function
self.__doc__ = getattr(function, '__doc__') self.__doc__ = getattr(function, '__doc__')
@@ -117,47 +67,13 @@ class _cached_property:
return value return value
class CachedSlotProperty:
if TYPE_CHECKING: def __init__(self, name, function):
from functools import cached_property as cached_property
from typing_extensions import ParamSpec
from .permissions import Permissions
from .abc import Snowflake
from .invite import Invite
from .template import Template
class _RequestLike(Protocol):
headers: Mapping[str, Any]
P = ParamSpec('P')
else:
cached_property = _cached_property
T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
_Iter = Union[Iterator[T], AsyncIterator[T]]
class CachedSlotProperty(Generic[T, T_co]):
def __init__(self, name: str, function: Callable[[T], T_co]) -> None:
self.name = name self.name = name
self.function = function self.function = function
self.__doc__ = getattr(function, '__doc__') self.__doc__ = getattr(function, '__doc__')
@overload def __get__(self, instance, owner):
def __get__(self, instance: None, owner: Type[T]) -> CachedSlotProperty[T, T_co]:
...
@overload
def __get__(self, instance: T, owner: Type[T]) -> T_co:
...
def __get__(self, instance: Optional[T], owner: Type[T]) -> Any:
if instance is None: if instance is None:
return self return self
@@ -168,122 +84,85 @@ class CachedSlotProperty(Generic[T, T_co]):
setattr(instance, self.name, value) setattr(instance, self.name, value)
return value return value
def cached_slot_property(name):
class classproperty(Generic[T_co]): def decorator(func):
def __init__(self, fget: Callable[[Any], T_co]) -> None:
self.fget = fget
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
return self.fget(owner)
def __set__(self, instance, value) -> None:
raise AttributeError('cannot set attribute')
def cached_slot_property(name: str) -> Callable[[Callable[[T], T_co]], CachedSlotProperty[T, T_co]]:
def decorator(func: Callable[[T], T_co]) -> CachedSlotProperty[T, T_co]:
return CachedSlotProperty(name, func) return CachedSlotProperty(name, func)
return decorator return decorator
class SequenceProxy(collections.abc.Sequence):
class SequenceProxy(Generic[T_co], collections.abc.Sequence):
"""Read-only proxy of a Sequence.""" """Read-only proxy of a Sequence."""
def __init__(self, proxied):
def __init__(self, proxied: Sequence[T_co]):
self.__proxied = proxied self.__proxied = proxied
def __getitem__(self, idx: int) -> T_co: def __getitem__(self, idx):
return self.__proxied[idx] return self.__proxied[idx]
def __len__(self) -> int: def __len__(self):
return len(self.__proxied) return len(self.__proxied)
def __contains__(self, item: Any) -> bool: def __contains__(self, item):
return item in self.__proxied return item in self.__proxied
def __iter__(self) -> Iterator[T_co]: def __iter__(self):
return iter(self.__proxied) return iter(self.__proxied)
def __reversed__(self) -> Iterator[T_co]: def __reversed__(self):
return reversed(self.__proxied) return reversed(self.__proxied)
def index(self, value: Any, *args, **kwargs) -> int: def index(self, value, *args, **kwargs):
return self.__proxied.index(value, *args, **kwargs) return self.__proxied.index(value, *args, **kwargs)
def count(self, value: Any) -> int: def count(self, value):
return self.__proxied.count(value) return self.__proxied.count(value)
@overload @overload
def parse_time(timestamp: None) -> None: def parse_time(timestamp: None) -> None:
... ...
@overload @overload
def parse_time(timestamp: str) -> datetime.datetime: def parse_time(timestamp: str) -> datetime.datetime:
... ...
@overload
def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
...
def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]: def parse_time(timestamp: Optional[str]) -> Optional[datetime.datetime]:
if timestamp: if timestamp:
return datetime.datetime.fromisoformat(timestamp) return datetime.datetime.fromisoformat(timestamp)
return None return None
def copy_doc(original):
def copy_doc(original: Callable) -> Callable[[T], T]: def decorator(overriden):
def decorator(overriden: T) -> T:
overriden.__doc__ = original.__doc__ overriden.__doc__ = original.__doc__
overriden.__signature__ = _signature(original) # type: ignore overriden.__signature__ = _signature(original)
return overriden return overriden
return decorator return decorator
def deprecated(instead=None):
def deprecated(instead: Optional[str] = None) -> Callable[[Callable[P, T]], Callable[P, T]]: def actual_decorator(func):
def actual_decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func) @functools.wraps(func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> T: def decorated(*args, **kwargs):
warnings.simplefilter('always', DeprecationWarning) # turn off filter warnings.simplefilter('always', DeprecationWarning) # turn off filter
if instead: if instead:
fmt = "{0.__name__} is deprecated, use {1} instead." fmt = "{0.__name__} is deprecated, use {1} instead."
else: else:
fmt = '{0.__name__} is deprecated.' fmt = '{0.__name__} is deprecated.'
warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning)
warnings.simplefilter('default', DeprecationWarning) # reset filter warnings.simplefilter('default', DeprecationWarning) # reset filter
return func(*args, **kwargs) return func(*args, **kwargs)
return decorated return decorated
return actual_decorator return actual_decorator
def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None, scopes=None):
def oauth_url(
client_id: Union[int, str],
*,
permissions: Permissions = MISSING,
guild: Snowflake = MISSING,
redirect_uri: str = MISSING,
scopes: Iterable[str] = MISSING,
disable_guild_select: bool = False,
) -> str:
"""A helper function that returns the OAuth2 URL for inviting the bot """A helper function that returns the OAuth2 URL for inviting the bot
into guilds. into guilds.
Parameters Parameters
----------- -----------
client_id: Union[:class:`int`, :class:`str`] client_id: :class:`str`
The client ID for your bot. The client ID for your bot.
permissions: :class:`~discord.Permissions` permissions: :class:`~discord.Permissions`
The permissions you're requesting. If not given then you won't be requesting any The permissions you're requesting. If not given then you won't be requesting any
permissions. permissions.
guild: :class:`~discord.abc.Snowflake` guild: :class:`~discord.Guild`
The guild to pre-select in the authorization screen, if available. The guild to pre-select in the authorization screen, if available.
redirect_uri: :class:`str` redirect_uri: :class:`str`
An optional valid redirect URI. An optional valid redirect URI.
@@ -291,10 +170,6 @@ def oauth_url(
An optional valid list of scopes. Defaults to ``('bot',)``. An optional valid list of scopes. Defaults to ``('bot',)``.
.. versionadded:: 1.7 .. versionadded:: 1.7
disable_guild_select: :class:`bool`
Whether to disallow the user from changing the guild dropdown.
.. versionadded:: 2.0
Returns Returns
-------- --------
@@ -302,17 +177,14 @@ def oauth_url(
The OAuth2 URL for inviting the bot into guilds. The OAuth2 URL for inviting the bot into guilds.
""" """
url = f'https://discord.com/oauth2/authorize?client_id={client_id}' url = f'https://discord.com/oauth2/authorize?client_id={client_id}'
url += '&scope=' + '+'.join(scopes or ('bot',)) url = url + '&scope=' + '+'.join(scopes or ('bot',))
if permissions is not MISSING: if permissions is not None:
url += f'&permissions={permissions.value}' url = url + '&permissions=' + str(permissions.value)
if guild is not MISSING: if guild is not None:
url += f'&guild_id={guild.id}' url = url + "&guild_id=" + str(guild.id)
if redirect_uri is not MISSING: if redirect_uri is not None:
from urllib.parse import urlencode from urllib.parse import urlencode
url = url + "&response_type=code&" + urlencode({'redirect_uri': redirect_uri})
url += '&response_type=code&' + urlencode({'redirect_uri': redirect_uri})
if disable_guild_select:
url += '&disable_guild_select=true'
return url return url
@@ -329,8 +201,7 @@ def snowflake_time(id: int) -> datetime.datetime:
An aware datetime in UTC representing the creation time of the snowflake. An aware datetime in UTC representing the creation time of the snowflake.
""" """
timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000 timestamp = ((id >> 22) + DISCORD_EPOCH) / 1000
return datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc) return datetime.datetime.utcfromtimestamp(timestamp).replace(tzinfo=datetime.timezone.utc)
def time_snowflake(dt: datetime.datetime, high: bool = False) -> int: def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
"""Returns a numeric snowflake pretending to be created at the given date. """Returns a numeric snowflake pretending to be created at the given date.
@@ -355,10 +226,9 @@ def time_snowflake(dt: datetime.datetime, high: bool = False) -> int:
The snowflake representing the time given. The snowflake representing the time given.
""" """
discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH) discord_millis = int(dt.timestamp() * 1000 - DISCORD_EPOCH)
return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) return (discord_millis << 22) + (2**22-1 if high else 0)
def find(predicate, seq):
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
"""A helper to return the first element found in the sequence """A helper to return the first element found in the sequence
that meets the predicate. For example: :: that meets the predicate. For example: ::
@@ -374,7 +244,7 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
----------- -----------
predicate predicate
A function that returns a boolean-like result. A function that returns a boolean-like result.
seq: :class:`collections.abc.Iterable` seq: iterable
The iterable to search through. The iterable to search through.
""" """
@@ -383,8 +253,7 @@ def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
return element return element
return None return None
def get(iterable, **attrs):
def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
r"""A helper that returns the first element in the iterable that meets r"""A helper that returns the first element in the iterable that meets
all the traits passed in ``attrs``. This is an alternative for all the traits passed in ``attrs``. This is an alternative for
:func:`~discord.utils.find`. :func:`~discord.utils.find`.
@@ -441,19 +310,22 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]:
return elem return elem
return None return None
converted = [(attrget(attr.replace('__', '.')), value) for attr, value in attrs.items()] converted = [
(attrget(attr.replace('__', '.')), value)
for attr, value in attrs.items()
]
for elem in iterable: for elem in iterable:
if _all(pred(elem) == value for pred, value in converted): if _all(pred(elem) == value for pred, value in converted):
return elem return elem
return None return None
def _unique(iterable):
seen = set()
adder = seen.add
return [x for x in iterable if not (x in seen or adder(x))]
def _unique(iterable: Iterable[T]) -> List[T]: def _get_as_snowflake(data, key):
return [x for x in dict.fromkeys(iterable)]
def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
try: try:
value = data[key] value = data[key]
except KeyError: except KeyError:
@@ -461,8 +333,7 @@ def _get_as_snowflake(data: Any, key: str) -> Optional[int]:
else: else:
return value and int(value) return value and int(value)
def _get_mime_type_for_image(data):
def _get_mime_type_for_image(data: bytes):
if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'): if data.startswith(b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'):
return 'image/png' return 'image/png'
elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'): elif data[0:3] == b'\xff\xd8\xff' or data[6:10] in (b'JFIF', b'Exif'):
@@ -474,31 +345,17 @@ def _get_mime_type_for_image(data: bytes):
else: else:
raise InvalidArgument('Unsupported image type given') raise InvalidArgument('Unsupported image type given')
def _bytes_to_base64_data(data):
def _bytes_to_base64_data(data: bytes) -> str:
fmt = 'data:{mime};base64,{data}' fmt = 'data:{mime};base64,{data}'
mime = _get_mime_type_for_image(data) mime = _get_mime_type_for_image(data)
b64 = b64encode(data).decode('ascii') b64 = b64encode(data).decode('ascii')
return fmt.format(mime=mime, data=b64) return fmt.format(mime=mime, data=b64)
def to_json(obj):
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
if HAS_ORJSON: def _parse_ratelimit_header(request, *, use_clock=False):
reset_after = request.headers.get('X-Ratelimit-Reset-After')
def _to_json(obj: Any) -> str: # type: ignore
return orjson.dumps(obj).decode('utf-8')
_from_json = orjson.loads # type: ignore
else:
def _to_json(obj: Any) -> str:
return json.dumps(obj, separators=(',', ':'), ensure_ascii=True)
_from_json = json.loads
def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After')
if use_clock or not reset_after: if use_clock or not reset_after:
utc = datetime.timezone.utc utc = datetime.timezone.utc
now = datetime.datetime.now(utc) now = datetime.datetime.now(utc)
@@ -507,7 +364,6 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
else: else:
return float(reset_after) return float(reset_after)
async def maybe_coroutine(f, *args, **kwargs): async def maybe_coroutine(f, *args, **kwargs):
value = f(*args, **kwargs) value = f(*args, **kwargs)
if _isawaitable(value): if _isawaitable(value):
@@ -515,7 +371,6 @@ async def maybe_coroutine(f, *args, **kwargs):
else: else:
return value return value
async def async_all(gen, *, check=_isawaitable): async def async_all(gen, *, check=_isawaitable):
for elem in gen: for elem in gen:
if check(elem): if check(elem):
@@ -524,9 +379,10 @@ async def async_all(gen, *, check=_isawaitable):
return False return False
return True return True
async def sane_wait_for(futures, *, timeout): async def sane_wait_for(futures, *, timeout):
ensured = [asyncio.ensure_future(fut) for fut in futures] ensured = [
asyncio.ensure_future(fut) for fut in futures
]
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
if len(pending) != 0: if len(pending) != 0:
@@ -534,23 +390,7 @@ async def sane_wait_for(futures, *, timeout):
return done return done
async def sleep_until(when, result=None):
def get_slots(cls: Type[Any]) -> Iterator[str]:
for mro in reversed(cls.__mro__):
try:
yield from mro.__slots__
except AttributeError:
continue
def compute_timedelta(dt: datetime.datetime):
if dt.tzinfo is None:
dt = dt.astimezone()
now = datetime.datetime.now(datetime.timezone.utc)
return max((dt - now).total_seconds(), 0)
async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]:
"""|coro| """|coro|
Sleep until a specified time. Sleep until a specified time.
@@ -567,14 +407,16 @@ async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Op
result: Any result: Any
If provided is returned to the caller when the coroutine completes. If provided is returned to the caller when the coroutine completes.
""" """
delta = compute_timedelta(when) if when.tzinfo is None:
return await asyncio.sleep(delta, result) when = when.astimezone()
now = datetime.datetime.now(datetime.timezone.utc)
delta = (when - now).total_seconds()
return await asyncio.sleep(max(delta, 0), result)
def utcnow() -> datetime.datetime: def utcnow() -> datetime.datetime:
"""A helper function to return an aware UTC datetime representing the current time. """A helper function to return an aware UTC datetime representing the current time.
This should be preferred to :meth:`datetime.datetime.utcnow` since it is an aware This should be preferred to :func:`datetime.datetime.utcnow` since it is an aware
datetime, compared to the naive datetime in the standard library. datetime, compared to the naive datetime in the standard library.
.. versionadded:: 2.0 .. versionadded:: 2.0
@@ -586,11 +428,9 @@ def utcnow() -> datetime.datetime:
""" """
return datetime.datetime.now(datetime.timezone.utc) return datetime.datetime.now(datetime.timezone.utc)
def valid_icon_size(size):
def valid_icon_size(size: int) -> bool:
"""Icons must be power of 2 within [16, 4096].""" """Icons must be power of 2 within [16, 4096]."""
return not size & (size - 1) and 4096 >= size >= 16 return not size & (size - 1) and size in range(16, 4097)
class SnowflakeList(array.array): class SnowflakeList(array.array):
"""Internal data storage class to efficiently store a list of snowflakes. """Internal data storage class to efficiently store a list of snowflakes.
@@ -606,31 +446,24 @@ class SnowflakeList(array.array):
__slots__ = () __slots__ = ()
if TYPE_CHECKING: def __new__(cls, data, *, is_sorted=False):
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data))
def __init__(self, data: Iterable[int], *, is_sorted: bool = False): def add(self, element):
...
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
def add(self, element: int) -> None:
i = bisect_left(self, element) i = bisect_left(self, element)
self.insert(i, element) self.insert(i, element)
def get(self, element: int) -> Optional[int]: def get(self, element):
i = bisect_left(self, element) i = bisect_left(self, element)
return self[i] if i != len(self) and self[i] == element else None return self[i] if i != len(self) and self[i] == element else None
def has(self, element: int) -> bool: def has(self, element):
i = bisect_left(self, element) i = bisect_left(self, element)
return i != len(self) and self[i] == element return i != len(self) and self[i] == element
_IS_ASCII = re.compile(r'^[\x00-\x7f]+$') _IS_ASCII = re.compile(r'^[\x00-\x7f]+$')
def _string_width(string, *, _IS_ASCII=_IS_ASCII):
def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
"""Returns string's width.""" """Returns string's width."""
match = _IS_ASCII.match(string) match = _IS_ASCII.match(string)
if match: if match:
@@ -640,8 +473,7 @@ def _string_width(string: str, *, _IS_ASCII=_IS_ASCII) -> int:
func = unicodedata.east_asian_width func = unicodedata.east_asian_width
return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string) return sum(2 if func(char) in UNICODE_WIDE_CHAR_TYPE else 1 for char in string)
def resolve_invite(invite):
def resolve_invite(invite: Union[Invite, str]) -> str:
""" """
Resolves an invite from a :class:`~discord.Invite`, URL or code. Resolves an invite from a :class:`~discord.Invite`, URL or code.
@@ -656,7 +488,6 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
The invite code. The invite code.
""" """
from .invite import Invite # circular import from .invite import Invite # circular import
if isinstance(invite, Invite): if isinstance(invite, Invite):
return invite.code return invite.code
else: else:
@@ -666,8 +497,7 @@ def resolve_invite(invite: Union[Invite, str]) -> str:
return m.group(1) return m.group(1)
return invite return invite
def resolve_template(code):
def resolve_template(code: Union[Template, str]) -> str:
""" """
Resolves a template code from a :class:`~discord.Template`, URL or code. Resolves a template code from a :class:`~discord.Template`, URL or code.
@@ -683,8 +513,7 @@ def resolve_template(code: Union[Template, str]) -> str:
:class:`str` :class:`str`
The template code. The template code.
""" """
from .template import Template # circular import from .template import Template # circular import
if isinstance(code, Template): if isinstance(code, Template):
return code.code return code.code
else: else:
@@ -694,8 +523,8 @@ def resolve_template(code: Union[Template, str]) -> str:
return m.group(1) return m.group(1)
return code return code
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c)
_MARKDOWN_ESCAPE_SUBREGEX = '|'.join(r'\{0}(?=([\s\S]*((?<!\{0})\{0})))'.format(c) for c in ('*', '`', '_', '~', '|')) for c in ('*', '`', '_', '~', '|'))
_MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)' _MARKDOWN_ESCAPE_COMMON = r'^>(?:>>)?\s|\[.+\]\(.+\)'
@@ -705,8 +534,7 @@ _URL_REGEX = r'(?P<url><[^: >]+:\/[^ >]+>|(?:https?|steam):\/\/[^\s<]+[^<.,:;\"\
_MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})' _MARKDOWN_STOCK_REGEX = fr'(?P<markdown>[_\\~|\*`]|{_MARKDOWN_ESCAPE_COMMON})'
def remove_markdown(text, *, ignore_links=True):
def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
"""A helper function that removes markdown characters. """A helper function that removes markdown characters.
.. versionadded:: 1.7 .. versionadded:: 1.7
@@ -739,8 +567,7 @@ def remove_markdown(text: str, *, ignore_links: bool = True) -> str:
regex = f'(?:{_URL_REGEX}|{regex})' regex = f'(?:{_URL_REGEX}|{regex})'
return re.sub(regex, replacement, text, 0, re.MULTILINE) return re.sub(regex, replacement, text, 0, re.MULTILINE)
def escape_markdown(text, *, as_needed=False, ignore_links=True):
def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str:
r"""A helper function that escapes Discord's markdown. r"""A helper function that escapes Discord's markdown.
Parameters Parameters
@@ -766,7 +593,6 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
""" """
if not as_needed: if not as_needed:
def replacement(match): def replacement(match):
groupdict = match.groupdict() groupdict = match.groupdict()
is_url = groupdict.get('url') is_url = groupdict.get('url')
@@ -782,8 +608,7 @@ def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool =
text = re.sub(r'\\', r'\\\\', text) text = re.sub(r'\\', r'\\\\', text)
return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text) return _MARKDOWN_ESCAPE_REGEX.sub(r'\\\1', text)
def escape_mentions(text):
def escape_mentions(text: str) -> str:
"""A helper function that escapes everyone, here, role, and user mentions. """A helper function that escapes everyone, here, role, and user mentions.
.. note:: .. note::
@@ -807,213 +632,3 @@ def escape_mentions(text: str) -> str:
The text with the mentions removed. The text with the mentions removed.
""" """
return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text) return re.sub(r'@(everyone|here|[!&]?[0-9]{17,20})', '@\u200b\\1', text)
def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
ret = []
n = 0
for item in iterator:
ret.append(item)
n += 1
if n == max_size:
yield ret
ret = []
n = 0
if ret:
yield ret
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
ret = []
n = 0
async for item in iterator:
ret.append(item)
n += 1
if n == max_size:
yield ret
ret = []
n = 0
if ret:
yield ret
@overload
def as_chunks(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
...
@overload
def as_chunks(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
...
def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
"""A helper function that collects an iterator into chunks of a given size.
.. versionadded:: 2.0
Parameters
----------
iterator: Union[:class:`collections.abc.Iterator`, :class:`collections.abc.AsyncIterator`]
The iterator to chunk, can be sync or async.
max_size: :class:`int`
The maximum chunk size.
.. warning::
The last chunk collected may not be as large as ``max_size``.
Returns
--------
Union[:class:`Iterator`, :class:`AsyncIterator`]
A new iterator which yields chunks of a given size.
"""
if max_size <= 0:
raise ValueError('Chunk sizes must be greater than 0.')
if isinstance(iterator, AsyncIterator):
return _achunk(iterator, max_size)
return _chunk(iterator, max_size)
PY_310 = sys.version_info >= (3, 10)
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
params = []
literal_cls = type(Literal[0])
for p in parameters:
if isinstance(p, literal_cls):
params.extend(p.__args__)
else:
params.append(p)
return tuple(params)
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
none_cls = type(None)
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
def evaluate_annotation(
tp: Any,
globals: Dict[str, Any],
locals: Dict[str, Any],
cache: Dict[str, Any],
*,
implicit_str: bool = True,
):
if isinstance(tp, ForwardRef):
tp = tp.__forward_arg__
# ForwardRefs always evaluate their internals
implicit_str = True
if implicit_str and isinstance(tp, str):
if tp in cache:
return cache[tp]
evaluated = eval(tp, globals, locals)
cache[tp] = evaluated
return evaluate_annotation(evaluated, globals, locals, cache)
if hasattr(tp, '__args__'):
implicit_str = True
is_literal = False
args = tp.__args__
if not hasattr(tp, '__origin__'):
if PY_310 and tp.__class__ is types.UnionType: # type: ignore
converted = Union[args] # type: ignore
return evaluate_annotation(converted, globals, locals, cache)
return tp
if tp.__origin__ is Union:
try:
if args.index(type(None)) != len(args) - 1:
args = normalise_optional_params(tp.__args__)
except ValueError:
pass
if tp.__origin__ is Literal:
if not PY_310:
args = flatten_literal_params(tp.__args__)
implicit_str = False
is_literal = True
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
raise TypeError('Literal arguments must be of type str, int, bool, or NoneType.')
if evaluated_args == args:
return tp
try:
return tp.copy_with(evaluated_args)
except AttributeError:
return tp.__origin__[evaluated_args]
return tp
def resolve_annotation(
annotation: Any,
globalns: Dict[str, Any],
localns: Optional[Dict[str, Any]],
cache: Optional[Dict[str, Any]],
) -> Any:
if annotation is None:
return type(None)
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
locals = globalns if localns is None else localns
if cache is None:
cache = {}
return evaluate_annotation(annotation, globalns, locals, cache)
TimestampStyle = Literal['f', 'F', 'd', 'D', 't', 'T', 'R']
def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) -> str:
"""A helper function to format a :class:`datetime.datetime` for presentation within Discord.
This allows for a locale-independent way of presenting data using Discord specific Markdown.
+-------------+----------------------------+-----------------+
| Style | Example Output | Description |
+=============+============================+=================+
| t | 22:57 | Short Time |
+-------------+----------------------------+-----------------+
| T | 22:57:58 | Long Time |
+-------------+----------------------------+-----------------+
| d | 17/05/2016 | Short Date |
+-------------+----------------------------+-----------------+
| D | 17 May 2016 | Long Date |
+-------------+----------------------------+-----------------+
| f (default) | 17 May 2016 22:57 | Short Date Time |
+-------------+----------------------------+-----------------+
| F | Tuesday, 17 May 2016 22:57 | Long Date Time |
+-------------+----------------------------+-----------------+
| R | 5 years ago | Relative Time |
+-------------+----------------------------+-----------------+
Note that the exact output depends on the user's locale setting in the client. The example output
presented is using the ``en-GB`` locale.
.. versionadded:: 2.0
Parameters
-----------
dt: :class:`datetime.datetime`
The datetime to format.
style: :class:`str`
The style to format the datetime with.
Returns
--------
:class:`str`
The formatted string.
"""
if style is None:
return f'<t:{int(dt.timestamp())}>'
return f'<t:{int(dt.timestamp())}:{style}>'

Some files were not shown because too many files have changed in this diff Show More