Add SKU subscriptions support

This commit is contained in:
MCausc78
2024-10-10 01:04:14 +03:00
committed by GitHub
parent 0ce75f3f53
commit 58b6929aa5
10 changed files with 445 additions and 12 deletions

View File

@ -25,16 +25,18 @@ DEALINGS IN THE SOFTWARE.
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
from typing import AsyncIterator, Optional, TYPE_CHECKING
from . import utils
from .errors import MissingApplicationID
from .enums import try_enum, SKUType, EntitlementType
from .flags import SKUFlags
from .object import Object
from .subscription import Subscription
if TYPE_CHECKING:
from datetime import datetime
from .abc import SnowflakeTime, Snowflake
from .guild import Guild
from .state import ConnectionState
from .types.sku import (
@ -100,6 +102,149 @@ class SKU:
""":class:`datetime.datetime`: Returns the sku's creation time in UTC."""
return utils.snowflake_time(self.id)
async def fetch_subscription(self, subscription_id: int, /) -> Subscription:
"""|coro|
Retrieves a :class:`.Subscription` with the specified ID.
.. versionadded:: 2.5
Parameters
-----------
subscription_id: :class:`int`
The subscription's ID to fetch from.
Raises
-------
NotFound
An subscription with this ID does not exist.
HTTPException
Fetching the subscription failed.
Returns
--------
:class:`.Subscription`
The subscription you requested.
"""
data = await self._state.http.get_sku_subscription(self.id, subscription_id)
return Subscription(data=data, state=self._state)
async def subscriptions(
self,
*,
limit: Optional[int] = 50,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
user: Snowflake,
) -> AsyncIterator[Subscription]:
"""Retrieves an :term:`asynchronous iterator` of the :class:`.Subscription` that SKU has.
.. versionadded:: 2.5
Examples
---------
Usage ::
async for subscription in sku.subscriptions(limit=100):
print(subscription.user_id, subscription.current_period_end)
Flattening into a list ::
subscriptions = [subscription async for subscription in sku.subscriptions(limit=100)]
# subscriptions is now a list of Subscription...
All parameters are optional.
Parameters
-----------
limit: Optional[:class:`int`]
The number of subscriptions to retrieve. If ``None``, it retrieves every subscription for this SKU.
Note, however, that this would make it a slow operation. Defaults to ``100``.
before: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]]
Retrieve subscriptions before this date or entitlement.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
after: Optional[Union[:class:`~discord.abc.Snowflake`, :class:`datetime.datetime`]]
Retrieve subscriptions after this date or entitlement.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
user: :class:`~discord.abc.Snowflake`
The user to filter by.
Raises
-------
HTTPException
Fetching the subscriptions failed.
TypeError
Both ``after`` and ``before`` were provided, as Discord does not
support this type of pagination.
Yields
--------
:class:`.Subscription`
The subscription with the SKU.
"""
if before is not None and after is not None:
raise TypeError('subscriptions pagination does not support both before and after')
# This endpoint paginates in ascending order.
endpoint = self._state.http.list_sku_subscriptions
async def _before_strategy(retrieve: int, before: Optional[Snowflake], limit: Optional[int]):
before_id = before.id if before else None
data = await endpoint(self.id, before=before_id, limit=retrieve, user_id=user.id)
if data:
if limit is not None:
limit -= len(data)
before = Object(id=int(data[0]['id']))
return data, before, limit
async def _after_strategy(retrieve: int, after: Optional[Snowflake], limit: Optional[int]):
after_id = after.id if after else None
data = await endpoint(
self.id,
after=after_id,
limit=retrieve,
user_id=user.id,
)
if data:
if limit is not None:
limit -= len(data)
after = Object(id=int(data[-1]['id']))
return data, after, limit
if isinstance(before, datetime):
before = Object(id=utils.time_snowflake(before, high=False))
if isinstance(after, datetime):
after = Object(id=utils.time_snowflake(after, high=True))
if before:
strategy, state = _before_strategy, before
else:
strategy, state = _after_strategy, after
while True:
retrieve = 100 if limit is None else min(limit, 100)
if retrieve < 1:
return
data, state, limit = await strategy(retrieve, state, limit)
# Terminate loop on next iteration; there's no data left after this
if len(data) < 1000:
limit = 0
for e in data:
yield Subscription(data=e, state=self._state)
class Entitlement:
"""Represents an entitlement from user or guild which has been granted access to a premium offering.
@ -190,17 +335,12 @@ class Entitlement:
Raises
-------
MissingApplicationID
The application ID could not be found.
NotFound
The entitlement could not be found.
HTTPException
Consuming the entitlement failed.
"""
if self.application_id is None:
raise MissingApplicationID
await self._state.http.consume_entitlement(self.application_id, self.id)
async def delete(self) -> None:
@ -210,15 +350,10 @@ class Entitlement:
Raises
-------
MissingApplicationID
The application ID could not be found.
NotFound
The entitlement could not be found.
HTTPException
Deleting the entitlement failed.
"""
if self.application_id is None:
raise MissingApplicationID
await self._state.http.delete_entitlement(self.application_id, self.id)