Codebase list python-graphene-sqlalchemy / master graphene_sqlalchemy / utils.py
master

Tree @master (Download .tar.gz)

utils.py @master raw · history · blame

from sqlalchemy.exc import ArgumentError
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import class_mapper, object_mapper
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError

from graphene import Argument, Enum, List


def get_session(context):
    """Fetch the session stored under the ``"session"`` key of *context*.

    Returns whatever ``context.get("session")`` yields, i.e. ``None`` for a
    dict-like context that has no such key.
    """
    session = context.get("session")
    return session


def get_query(model, context):
    """Return the query object to use for *model*.

    A truthy ``query`` attribute on the model takes precedence. Otherwise the
    query is created from the session found in *context*.

    Raises
    - Exception
        If the model has no truthy ``query`` attribute and no truthy session
        can be obtained from *context*.
    """
    model_query = getattr(model, "query", None)
    if model_query:
        return model_query

    # No query on the model itself: fall back to the session from the context.
    session = get_session(context)
    if session:
        return session.query(model)

    raise Exception(
        "A query in the model Base or a session in the schema is required for querying.\n"
        "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying"
    )


def is_mapped_class(cls):
    """Tell whether ``class_mapper`` accepts *cls*.

    Returns ``False`` when ``class_mapper(cls)`` raises ``ArgumentError`` or
    ``UnmappedClassError``, and ``True`` when it completes without raising.
    """
    mapped = True
    try:
        class_mapper(cls)
    except (ArgumentError, UnmappedClassError):
        mapped = False
    return mapped


def is_mapped_instance(cls):
    """Tell whether ``object_mapper`` accepts the given object.

    The parameter is named ``cls`` but is handed to ``object_mapper`` as-is.

    Returns ``False`` when ``object_mapper(cls)`` raises ``ArgumentError`` or
    ``UnmappedInstanceError``, and ``True`` when it completes without raising.
    """
    mapped = True
    try:
        object_mapper(cls)
    except (ArgumentError, UnmappedInstanceError):
        mapped = False
    return mapped


def _symbol_name(column_name, is_asc):
    """Build the default sort symbol for a column: ``<name>_asc`` or ``<name>_desc``."""
    suffix = "_asc" if is_asc else "_desc"
    return column_name + suffix


class EnumValue(str):
    """A string that additionally carries an arbitrary payload in its ``value`` attribute."""

    def __new__(cls, str_value, value):
        # str is immutable, so the text must be supplied at construction time.
        # `value` is accepted here only so the signature matches __init__.
        instance = super(EnumValue, cls).__new__(cls, str_value)
        return instance

    def __init__(self, str_value, value):
        # The text was already consumed by __new__; only the payload is stored here.
        super(EnumValue, self).__init__()
        self.value = value


# Cache for the generated enums, to avoid name clash.
# Maps an enum name to the `(enum, default)` tuple built for it by
# `_sort_enum_for_model`, so asking for the same name again reuses that enum.
_ENUM_CACHE = {}


def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
    """Build, or fetch from the cache, the sort enum and default sort for a model.

    Parameters
    - cls : Sqlalchemy model class
        Model whose columns supply the enum members
    - name : str, optional, default None
        Name of the enum; falls back to `cls.__name__ + 'SortEnum'` when falsy
    - symbol_name : function, optional, default `_symbol_name`
        Called with the column name and a boolean (True for ascending) to obtain
        the member name for that column and direction

    Returns
    - tuple
        `(enum, default)`, where `default` is the list of ascending members of
        the primary key column(s)

    Note: the cache is keyed by the enum name only, so on a cache hit `cls` and
    `symbol_name` are not consulted.
    """
    if not name:
        name = cls.__name__ + "SortEnum"

    try:
        return _ENUM_CACHE[name]
    except KeyError:
        pass

    members = []
    primary_key_sort = []
    for col in inspect(cls).columns.values():
        ascending_key = symbol_name(col.name, True)
        ascending = EnumValue(ascending_key, col.asc())
        descending_key = symbol_name(col.name, False)
        descending = EnumValue(descending_key, col.desc())
        members.append((ascending_key, ascending))
        members.append((descending_key, descending))
        if col.primary_key:
            primary_key_sort.append(ascending)

    entry = (Enum(name, members), primary_key_sort)
    _ENUM_CACHE[name] = entry
    return entry


def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
    """Return the Graphene Enum used to sort queries on a SQLAlchemy model

    Parameters
    - cls : Sqlalchemy model class
        Model whose columns become the members of the sort enumerator
    - name : str, optional, default None
        Name of the enumerator. When omitted, `cls.__name__ + 'SortEnum'` is used
    - symbol_name : function, optional, default `_symbol_name`
        Function receiving the column name and a boolean telling whether the sort
        direction is ascending, and returning the symbol name for that pair.
        With the default function a column named 'foo' produces the symbols
        'foo_asc' and 'foo_desc'

    Returns
    - Enum
        The Graphene enumerator
    """
    # Only the enum is wanted here; the default sort list is discarded.
    return _sort_enum_for_model(cls, name, symbol_name)[0]


def sort_argument_for_model(cls, has_default=True):
    """Build the Graphene argument for the sort field of a model.

    The argument accepts a list of members of the model's sort enum.

    Parameters
    - cls : Sqlalchemy model class
        Model for which the sort argument is created
    - has_default : bool, optional, default True
        When true, the argument defaults to ascending order on the primary
        key column(s); otherwise its default value is None

    Returns
    - Argument
        The Graphene argument wrapping a list of the sort enum
    """
    sort_enum, primary_key_sort = _sort_enum_for_model(cls)
    default_value = primary_key_sort if has_default else None
    return Argument(List(sort_enum), default_value=default_value)