Source code for openapi.pagination.cursor

import base64
from dataclasses import dataclass
from datetime import date, datetime
from functools import cached_property
from typing import Any, Dict, Optional, Tuple, Type

from dateutil.parser import parse as parse_date
from yarl import URL

from openapi import json
from openapi.data.fields import Choice, integer_field, str_field
from openapi.data.validate import ValidationErrors

from .pagination import (
    DEF_PAGINATION_LIMIT,
    MAX_PAGINATION_LIMIT,
    Pagination,
    PaginationVisitor,
    fields_flip_sign,
    fields_no_sign,
    from_filters_and_dataclass,
)

CursorType = Tuple[Tuple[str, str], ...]


def encode_cursor(data: Tuple[str, ...], previous: bool = False) -> str:
    cursor_bytes = json.dumps((data, previous)).encode("ascii")
    base64_bytes = base64.b64encode(cursor_bytes)
    return base64_bytes.decode("ascii")


def decode_cursor(
    cursor: Optional[str], field_names: Tuple[str]
) -> Tuple[CursorType, bool]:
    try:
        if cursor:
            base64_bytes = cursor.encode("ascii")
            cursor_bytes = base64.b64decode(base64_bytes)
            values, previous = json.loads(cursor_bytes)
            if len(values) == len(field_names):
                return tuple(zip(field_names, values)), previous
            raise ValueError
        return (), False
    except Exception as e:
        raise ValidationErrors("invalid cursor") from e


def cursor_url(url: URL, cursor: str) -> URL:
    query = url.query.copy()
    query.update(_cursor=cursor)
    return url.with_query(query)


def start_values(record: dict, field_names: Tuple[str, ...]) -> Tuple[str, ...]:
    """start values for pagination"""
    return tuple(record[field] for field in field_names)


[docs]def cursorPagination( *order_by_fields: str, default_limit: int = DEF_PAGINATION_LIMIT, max_limit: int = MAX_PAGINATION_LIMIT, ) -> Type[Pagination]: if len(order_by_fields) == 0: raise ValueError("orderable_fields must be specified") field_names = fields_no_sign(order_by_fields) @dataclass class CursorPagination(Pagination): limit: int = integer_field( min_value=1, max_value=max_limit, default=default_limit, required=False, description="Limit the number of objects returned from the endpoint", ) direction: str = str_field( validator=Choice(("asc", "desc")), required=False, default="asc", description=( f"Sort results via `{', '.join(order_by_fields)}` " "in descending or ascending order" ), ) _cursor: str = str_field(default="", hidden=True) @cached_property def cursor_info(self) -> Tuple[CursorType, Tuple[str, ...], bool]: order_by = ( fields_flip_sign(order_by_fields) if self.direction == "desc" else order_by_fields ) cursor, previous = decode_cursor(self._cursor, order_by) return cursor, order_by, previous @property def previous(self) -> bool: return self.cursor_info[2] def apply(self, visitor: PaginationVisitor) -> None: cursor, order_by, previous = self.cursor_info visitor.apply_cursor_pagination( cursor, self.limit, order_by, previous=previous, ) @classmethod def create_pagination(cls, data: dict) -> "CursorPagination": return from_filters_and_dataclass(CursorPagination, data) def links( self, url: URL, data: list, total: Optional[int] = None ) -> Dict[str, str]: links = {} if self.previous: if len(data) > self.limit + 1: links["prev"] = cursor_url( url, encode_cursor( start_values(data[self.limit], field_names), previous=True ), ) if self._cursor: links["next"] = cursor_url( url, encode_cursor( start_values(data[0], field_names), ), ) else: if len(data) > self.limit: links["next"] = cursor_url( url, encode_cursor(start_values(data[self.limit], field_names)), ) if self._cursor: links["prev"] = cursor_url( url, encode_cursor( start_values(data[0], field_names), previous=True, ), ) return links def get_data(self, data: list) -> list: if self.previous: data = list(reversed(data[1:])) return data if len(data) <= self.limit else data[1:] return data if len(data) <= self.limit else data[: self.limit] return CursorPagination
def cursor_to_python(py_type: Type, value: Any) -> Any: try: if py_type is datetime: return parse_date(value) elif py_type is date: return parse_date(value).date() elif py_type is int: return int(value) else: return value except Exception as e: raise ValidationErrors("invalid cursor") from e