diff --git a/.gitignore b/.gitignore index e4070f3..a97b8c2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ # Distribution / packaging .Python env/ +.venv/ build/ develop-eggs/ dist/ @@ -25,6 +26,7 @@ *.egg-info/ .installed.cfg *.egg +.python-version # PyInstaller # Usually these files are written by a python script from a template @@ -46,6 +48,7 @@ coverage.xml *,cover .pytest_cache/ +.benchmarks/ # Translations *.mo diff --git a/.travis.yml b/.travis.yml index 39151a5..5a98842 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,9 +4,6 @@ # Python 2.7 - env: TOXENV=py27 python: 2.7 - # Python 3.5 - - env: TOXENV=py34 - python: 3.4 # Python 3.5 - env: TOXENV=py35 python: 3.5 diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..879520f --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +/ @cito @jnak @Nabellaleen diff --git a/README.md b/README.md index 2ba0d1c..9b61706 100644 --- a/README.md +++ b/README.md @@ -43,10 +43,10 @@ class User(SQLAlchemyObjectType): class Meta: model = UserModel - # only return specified fields - only_fields = ("name",) - # exclude specified fields - exclude_fields = ("last_name",) + # use `only_fields` to only expose specific fields ie "name" + # only_fields = ("name",) + # use `exclude_fields` to exclude specific fields ie "last_name" + # exclude_fields = ("last_name",) class Query(graphene.ObjectType): users = graphene.List(User) diff --git a/docs/examples.rst b/docs/examples.rst index 283a0f5..2013cfb 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -13,20 +13,10 @@ interfaces = (relay.Node,) - class BookConnection(relay.Connection): - class Meta: - node = Book - - class Author(SQLAlchemyObjectType): class Meta: model = AuthorModel interfaces = (relay.Node,) - - - class AuthorConnection(relay.Connection): - class Meta: - node = Author class SearchResult(graphene.Union): @@ -39,8 +29,8 @@ search = graphene.List(SearchResult, q=graphene.String()) # List field for search results # Normal Fields - all_books = SQLAlchemyConnectionField(BookConnection) - all_authors = SQLAlchemyConnectionField(AuthorConnection) + all_books = SQLAlchemyConnectionField(Book.connection) + all_authors = SQLAlchemyConnectionField(Author.connection) def resolve_search(self, info, **args): q = args.get("q") # Search query diff --git a/docs/tips.rst b/docs/tips.rst index 1fd3910..baa8233 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -50,13 +50,8 @@ model = Pet - class PetConnection(Connection): - class Meta: - node = PetNode - - class Query(ObjectType): - allPets = SQLAlchemyConnectionField(PetConnection) + allPets = SQLAlchemyConnectionField(PetNode.connection) some of the allowed queries are diff --git a/docs/tutorial.rst b/docs/tutorial.rst index bc5ee62..3c4c135 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -102,28 +102,18 @@ interfaces = (relay.Node, ) - class DepartmentConnection(relay.Connection): - class Meta: - node = Department - - class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel interfaces = (relay.Node, ) - class EmployeeConnection(relay.Connection): - class Meta: - node = Employee - - class Query(graphene.ObjectType): node = relay.Node.Field() # Allows sorting over multiple columns, by default over the primary key - all_employees = SQLAlchemyConnectionField(EmployeeConnection) + all_employees = SQLAlchemyConnectionField(Employee.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None) + all_departments = 
SQLAlchemyConnectionField(Department.connection, sort=None) schema = graphene.Schema(query=Query) diff --git a/examples/flask_sqlalchemy/README.md b/examples/flask_sqlalchemy/README.md index 7e44686..d08b484 100644 --- a/examples/flask_sqlalchemy/README.md +++ b/examples/flask_sqlalchemy/README.md @@ -9,7 +9,7 @@ --------------- First you'll need to get the source of the project. Do this by cloning the -whole Graphene repository: +whole Graphene-SQLAlchemy repository: ```bash # Get the example project code diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index a4d3f29..1066020 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -1,43 +1,46 @@ #!/usr/bin/env python +from database import db_session, init_db from flask import Flask +from schema import schema from flask_graphql import GraphQLView - -from .database import db_session, init_db -from .schema import schema app = Flask(__name__) app.debug = True -default_query = ''' +example_query = """ { - allEmployees { + allEmployees(sort: [NAME_ASC, ID_ASC]) { edges { node { - id, - name, + id + name department { - id, + id name - }, + } role { - id, + id name } } } } -}'''.strip() +} +""" -app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True)) +app.add_url_rule( + "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True) +) @app.teardown_appcontext def shutdown_session(exception=None): db_session.remove() -if __name__ == '__main__': + +if __name__ == "__main__": init_db() app.run() diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index 01e76ca..ca4d412 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -14,7 +14,7 @@ # import all modules here that might define models so that # they will be registered properly on the metadata. 
Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index e164c01..efbbe69 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -1,7 +1,6 @@ +from database import Base from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import backref, relationship - -from .database import Base class Department(Base): diff --git a/examples/flask_sqlalchemy/requirements.txt b/examples/flask_sqlalchemy/requirements.txt index 337ff60..fa2c13a 100644 --- a/examples/flask_sqlalchemy/requirements.txt +++ b/examples/flask_sqlalchemy/requirements.txt @@ -1,4 +1,2 @@ -graphene[sqlalchemy] -SQLAlchemy==1.0.11 -Flask==0.12.4 -Flask-GraphQL==1.3.0 +-e ../../ +Flask-GraphQL diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index cbee081..ea525e3 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,11 +1,10 @@ +from models import Department as DepartmentModel +from models import Employee as EmployeeModel +from models import Role as RoleModel + import graphene from graphene import relay -from graphene_sqlalchemy import (SQLAlchemyConnectionField, - SQLAlchemyObjectType, utils) - -from .models import Department as DepartmentModel -from .models import Employee as EmployeeModel -from .models import Role as RoleModel +from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType class Department(SQLAlchemyObjectType): @@ -26,22 +25,15 @@ interfaces = (relay.Node, ) -SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee', - lambda c, d: c.upper() + ('_ASC' if d else '_DESC')) - - class Query(graphene.ObjectType): node = relay.Node.Field() # Allow only single column sorting all_employees = SQLAlchemyConnectionField( - Employee, - sort=graphene.Argument( - SortEnumEmployee, - default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc()))) + Employee.connection, sort=Employee.sort_argument()) # Allows sorting over multiple columns, by default over the primary key - all_roles = SQLAlchemyConnectionField(Role) + all_roles = SQLAlchemyConnectionField(Role.connection) # Disable sorting over this field - all_departments = SQLAlchemyConnectionField(Department, sort=None) + all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git a/examples/nameko_sqlalchemy/README.md b/examples/nameko_sqlalchemy/README.md index 39cfe92..e080389 100644 --- a/examples/nameko_sqlalchemy/README.md +++ b/examples/nameko_sqlalchemy/README.md @@ -14,7 +14,7 @@ --------------- First you'll need to get the source of the project. 
Do this by cloning the -whole Graphene repository: +whole Graphene-SQLAlchemy repository: ```bash # Get the example project code @@ -46,7 +46,6 @@ ```bash ./run.sh - ``` Now head on over to postman and send POST request to: diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py index 42a40a0..0535252 100755 --- a/examples/nameko_sqlalchemy/app.py +++ b/examples/nameko_sqlalchemy/app.py @@ -1,9 +1,9 @@ +from database import db_session, init_db +from schema import schema + from graphql_server import (HttpQueryError, default_format_error, encode_execution_results, json_encode, load_json_body, run_http_query) - -from .database import db_session, init_db -from .schema import schema class App(): diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py index 01e76ca..ca4d412 100644 --- a/examples/nameko_sqlalchemy/database.py +++ b/examples/nameko_sqlalchemy/database.py @@ -14,7 +14,7 @@ # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from .models import Department, Employee, Role + from models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/nameko_sqlalchemy/models.py b/examples/nameko_sqlalchemy/models.py index e164c01..efbbe69 100644 --- a/examples/nameko_sqlalchemy/models.py +++ b/examples/nameko_sqlalchemy/models.py @@ -1,7 +1,6 @@ +from database import Base from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import backref, relationship - -from .database import Base class Department(Base): diff --git a/examples/nameko_sqlalchemy/requirements.txt b/examples/nameko_sqlalchemy/requirements.txt index be037f7..617d487 100644 --- a/examples/nameko_sqlalchemy/requirements.txt +++ b/examples/nameko_sqlalchemy/requirements.txt @@ -1,4 +1,3 @@ -graphene[sqlalchemy] -SQLAlchemy==1.0.11 +-e ../../ +graphql-server-core nameko -graphql-server-core diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py index fa74735..ced300b 100644 --- a/examples/nameko_sqlalchemy/schema.py +++ b/examples/nameko_sqlalchemy/schema.py @@ -1,38 +1,35 @@ +from models import Department as DepartmentModel +from models import Employee as EmployeeModel +from models import Role as RoleModel + import graphene from graphene import relay from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType -from .models import Department as DepartmentModel -from .models import Employee as EmployeeModel -from .models import Role as RoleModel - class Department(SQLAlchemyObjectType): - class Meta: model = DepartmentModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Employee(SQLAlchemyObjectType): - class Meta: model = EmployeeModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Role(SQLAlchemyObjectType): - class Meta: model = RoleModel - interfaces = (relay.Node, ) + interfaces = (relay.Node,) class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) - all_roles = SQLAlchemyConnectionField(Role) + all_employees = SQLAlchemyConnectionField(Employee.connection) + all_roles = SQLAlchemyConnectionField(Role.connection) role = graphene.Field(Role) -schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) +schema = graphene.Schema(query=Query) diff --git 
a/examples/nameko_sqlalchemy/service.py b/examples/nameko_sqlalchemy/service.py index 9815750..d9c519c 100644 --- a/examples/nameko_sqlalchemy/service.py +++ b/examples/nameko_sqlalchemy/service.py @@ -1,7 +1,6 @@ #!/usr/bin/env python +from app import App from nameko.web.handlers import http - -from .app import App class DepartmentService: diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index d328304..3945d50 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -2,7 +2,7 @@ from .fields import SQLAlchemyConnectionField from .utils import get_query, get_session -__version__ = "2.1.2" +__version__ = "2.3.0" __all__ = [ "__version__", diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py new file mode 100644 index 0000000..baf01de --- /dev/null +++ b/graphene_sqlalchemy/batching.py @@ -0,0 +1,72 @@ +import sqlalchemy +from promise import dataloader, promise +from sqlalchemy.orm import Session, strategies +from sqlalchemy.orm.query import QueryContext + + +def get_batch_resolver(relationship_prop): + + # Cache this across `batch_load_fn` calls + # This is so SQL string generation is cached under-the-hood via `bakery` + selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) + + class RelationshipLoader(dataloader.DataLoader): + cache = False + + def batch_load_fn(self, parents): # pylint: disable=method-hidden + """ + Batch loads the relationships of all the parents as one SQL statement. + + There is no way to do this out-of-the-box with SQLAlchemy but + we can piggyback on some internal APIs of the `selectin` + eager loading strategy. It's a bit hacky but it's preferable + to re-implementing and maintaining a big chunk of the `selectin` + loader logic ourselves. + + The approach here is to build a regular query that + selects the parent and `selectin` load the relationship. + But instead of having the query emit 2 `SELECT` statements + when calling `all()`, we skip the first `SELECT` statement + and jump right before the `selectin` loader is called. + To accomplish this, we have to construct objects that are + normally built in the first part of the query in order + to call directly `SelectInLoader._load_for_path`. + + TODO Move this logic to a util in the SQLAlchemy repo as per + SQLAlchemy's main maintainer's suggestion. + See https://git.io/JewQ7 + """ + child_mapper = relationship_prop.mapper + parent_mapper = relationship_prop.parent + session = Session.object_session(parents[0]) + + # These issues are very unlikely to happen in practice... + for parent in parents: + # assert parent.__mapper__ is parent_mapper + # All instances must share the same session + assert session is Session.object_session(parent) + # The behavior of `selectin` is undefined if the parent is dirty + assert parent not in session.dirty + + # Should the boolean be set to False? Does it matter for our purposes?
+ states = [(sqlalchemy.inspect(parent), True) for parent in parents] + + # For our purposes, the query_context will only used to get the session + query_context = QueryContext(session.query(parent_mapper.entity)) + + selectin_loader._load_for_path( + query_context, + parent_mapper._path_registry, + states, + None, + child_mapper, + ) + + return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) + + loader = RelationshipLoader() + + def resolve(root, info, **args): + return loader.load(root) + + return resolve diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 7cc259e..f4b805e 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -1,11 +1,20 @@ +from enum import EnumMeta + from singledispatch import singledispatch from sqlalchemy import types from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import interfaces +from sqlalchemy.orm import interfaces, strategies from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, String) from graphene.types.json import JSONString + +from .batching import get_batch_resolver +from .enums import enum_for_sa_enum +from .fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) +from .registry import get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType @@ -13,6 +22,9 @@ ChoiceType = JSONType = ScalarListType = TSVectorType = object +is_selectin_available = getattr(strategies, 'SelectInLoader', None) + + def get_column_doc(column): return getattr(column, "doc", None) @@ -21,43 +33,110 @@ return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): - direction = relationship.direction - model = relationship.mapper.entity - +def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, + orm_field_name, **field_kwargs): + """ + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param function|None connection_field_factory: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Dynamic + """ def dynamic_type(): - _type = registry.get_type_for_model(model) - if not _type: + """:rtype: Field|None""" + direction = relationship_prop.direction + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + batching_ = batching if is_selectin_available else False + + if not child_type: return None - if direction == interfaces.MANYTOONE or not relationship.uselist: - return Field(_type) - elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): - if _type._meta.connection: - return connection_field_factory(relationship, registry) - return Field(List(_type)) + + if direction == interfaces.MANYTOONE or not relationship_prop.uselist: + return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, + **field_kwargs) + + if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): + return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, + connection_field_factory, **field_kwargs) return Dynamic(dynamic_type) -def convert_sqlalchemy_hybrid_method(hybrid_item): - return String(description=getattr(hybrid_item, "__doc__", None), required=False) - - -def convert_sqlalchemy_composite(composite, registry): - converter 
= registry.get_converter_for_composite(composite.composite_class) +def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): + """ + Convert a one-to-one or many-to-one relationship. Return an object field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param str orm_field_name: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + resolver = get_custom_resolver(obj_type, orm_field_name) + if resolver is None: + resolver = get_batch_resolver(relationship_prop) if batching else \ + get_attr_resolver(obj_type, relationship_prop.key) + + return Field(child_type, resolver=resolver, **field_kwargs) + + +def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): + """ + Convert a one-to-many or many-to-many relationship. Return a list field or a connection field. + + :param sqlalchemy.RelationshipProperty relationship_prop: + :param SQLAlchemyObjectType obj_type: + :param bool batching: + :param function|None connection_field_factory: + :param dict field_kwargs: + :rtype: Field + """ + child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) + + if not child_type._meta.connection: + return Field(List(child_type), **field_kwargs) + + # TODO Allow override of connection_field_factory and resolver via ORMField + if connection_field_factory is None: + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ + default_connection_field_factory + + return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) + + +def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): + if 'type' not in field_kwargs: + # TODO The default type should be dependent on the type of the property.
+ field_kwargs['type'] = String + + return Field( + resolver=resolver, + **field_kwargs + ) + + +def convert_sqlalchemy_composite(composite_prop, registry, resolver): + converter = registry.get_converter_for_composite(composite_prop.composite_class) if not converter: try: raise Exception( "Don't know how to convert the composite field %s (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) except AttributeError: # handle fields that are not attached to a class yet (don't have a parent) raise Exception( "Don't know how to convert the composite field %r (%s)" - % (composite, composite.composite_class) + % (composite_prop, composite_prop.composite_class) ) - return converter(composite, registry) + + # TODO Add a way to override composite fields default parameters + return converter(composite_prop, registry) def _register_composite_class(cls, registry=None): @@ -75,8 +154,16 @@ convert_sqlalchemy_composite.register = _register_composite_class -def convert_sqlalchemy_column(column, registry=None): - return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) +def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): + column = column_prop.columns[0] + field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) + field_kwargs.setdefault('required', not is_column_nullable(column)) + field_kwargs.setdefault('description', get_column_doc(column)) + + return Field( + resolver=resolver, + **field_kwargs + ) @singledispatch @@ -98,99 +185,69 @@ @convert_sqlalchemy_type.register(postgresql.CIDR) @convert_sqlalchemy_type.register(TSVectorType) def convert_column_to_string(type, column, registry=None): - return String( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return String @convert_sqlalchemy_type.register(types.DateTime) def convert_column_to_datetime(type, column, registry=None): from graphene.types.datetime import DateTime - - return DateTime( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return DateTime @convert_sqlalchemy_type.register(types.SmallInteger) @convert_sqlalchemy_type.register(types.Integer) def convert_column_to_int_or_id(type, column, registry=None): - if column.primary_key: - return ID( - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) - else: - return Int( - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) + return ID if column.primary_key else Int @convert_sqlalchemy_type.register(types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return Boolean @convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Numeric) @convert_sqlalchemy_type.register(types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return Float @convert_sqlalchemy_type.register(types.Enum) def convert_enum_to_enum(type, column, registry=None): - enum_class = getattr(type, 'enum_class', None) - if enum_class: # Check if an enum.Enum type is used - graphene_type = Enum.from_enum(enum_class) - else: # Nope, just a list of string options - items = zip(type.enums, type.enums) - graphene_type = Enum(type.name, items) - return Field( - graphene_type, - 
description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) - - + return lambda: enum_for_sa_enum(type, registry or get_global_registry()) + + +# TODO Make ChoiceType conversion consistent with other enums @convert_sqlalchemy_type.register(ChoiceType) -def convert_column_to_enum(type, column, registry=None): +def convert_choice_to_enum(type, column, registry=None): name = "{}_{}".format(column.table.name, column.name).upper() - return Enum(name, type.choices, description=get_column_doc(column)) + if isinstance(type.choices, EnumMeta): + # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta + # do not use from_enum here because we can have more than one enum column in table + return Enum(name, list((v.name, v.value) for v in type.choices)) + else: + return Enum(name, type.choices) @convert_sqlalchemy_type.register(ScalarListType) def convert_scalar_list_to_list(type, column, registry=None): - return List(String, description=get_column_doc(column)) - - + return List(String) + + +@convert_sqlalchemy_type.register(types.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_postgres_array_to_list(_type, column, registry=None): - graphene_type = convert_sqlalchemy_type(column.type.item_type, column) - inner_type = type(graphene_type) - return List( - inner_type, - description=get_column_doc(column), - required=not (is_column_nullable(column)), - ) +def convert_array_to_list(_type, column, registry=None): + inner_type = convert_sqlalchemy_type(column.type.item_type, column) + return List(inner_type) @convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSONB) def convert_json_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return JSONString @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): - return JSONString( - description=get_column_doc(column), required=not (is_column_nullable(column)) - ) + return JSONString diff --git a/graphene_sqlalchemy/enums.py b/graphene_sqlalchemy/enums.py new file mode 100644 index 0000000..0adea10 --- /dev/null +++ b/graphene_sqlalchemy/enums.py @@ -0,0 +1,206 @@ +import six +from sqlalchemy.orm import ColumnProperty +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Argument, Enum, List + +from .utils import EnumValue, to_enum_value_name, to_type_name + + +def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): + """Convert the given SQLAlchemy Enum type to a Graphene Enum type. + + The name of the Graphene Enum will be determined as follows: + If the SQLAlchemy Enum is based on a Python Enum, use the name + of the Python Enum. Otherwise, if the SQLAlchemy Enum is named, + use the SQL name after conversion to a type name. Otherwise, use + the given fallback_name or raise an error if it is empty. + + The Enum value names are converted to upper case if necessary. 
+ """ + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum_class = sa_enum.enum_class + if enum_class: + if all(to_enum_value_name(key) == key for key in enum_class.__members__): + return Enum.from_enum(enum_class) + name = enum_class.__name__ + members = [ + (to_enum_value_name(key), value.value) + for key, value in enum_class.__members__.items() + ] + else: + sql_enum_name = sa_enum.name + if sql_enum_name: + name = to_type_name(sql_enum_name) + elif fallback_name: + name = fallback_name + else: + raise TypeError("No type name specified for {!r}".format(sa_enum)) + members = [(to_enum_value_name(key), key) for key in sa_enum.enums] + return Enum(name, members) + + +def enum_for_sa_enum(sa_enum, registry): + """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + enum = _convert_sa_to_graphene_enum(sa_enum) + registry.register_enum(sa_enum, enum) + return enum + + +def enum_for_field(obj_type, field_name): + """Return the Graphene Enum type for the specified Graphene field.""" + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) + if not field_name or not isinstance(field_name, six.string_types): + raise TypeError( + "Expected a field name, but got: {!r}".format(field_name)) + registry = obj_type._meta.registry + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if orm_field is None: + raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) + if not isinstance(orm_field, ColumnProperty): + raise TypeError( + "{}.{} does not map to model column".format(obj_type._meta.name, field_name) + ) + column = orm_field.columns[0] + sa_enum = column.type + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) + ) + enum = registry.get_graphene_enum_for_sa_enum(sa_enum) + if not enum: + fallback_name = obj_type._meta.name + to_type_name(field_name) + enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name) + registry.register_enum(sa_enum, enum) + return enum + + +def _default_sort_enum_symbol_name(column_name, sort_asc=True): + return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC") + + +def sort_enum_for_object_type( + obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None +): + """Return Graphene Enum for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Enum shall be generated. + - name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. 
- get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + + Returns + - Enum + The Graphene Enum type + """ + name = name or obj_type._meta.name + "SortEnum" + registry = obj_type._meta.registry + enum = registry.get_sort_enum_for_object_type(obj_type) + custom_options = dict( + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if enum: + if name != enum.__name__ or custom_options != enum.custom_options: + raise ValueError( + "Sort enum for {} has already been customized".format(obj_type) + ) + else: + members = [] + default = [] + fields = obj_type._meta.fields + get_name = get_symbol_name or _default_sort_enum_symbol_name + for field_name in fields: + if only_fields and field_name not in only_fields: + continue + orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) + if not isinstance(orm_field, ColumnProperty): + continue + column = orm_field.columns[0] + if only_indexed and not (column.primary_key or column.index): + continue + asc_name = get_name(column.name, True) + asc_value = EnumValue(asc_name, column.asc()) + desc_name = get_name(column.name, False) + desc_value = EnumValue(desc_name, column.desc()) + if column.primary_key: + default.append(asc_value) + members.extend(((asc_name, asc_value), (desc_name, desc_value))) + enum = Enum(name, members) + enum.default = default # store default as attribute + enum.custom_options = custom_options + registry.register_sort_enum(obj_type, enum) + return enum + + +def sort_argument_for_object_type( + obj_type, + enum_name=None, + only_fields=None, + only_indexed=None, + get_symbol_name=None, + has_default=True, +): + """Returns Graphene Argument for sorting the given SQLAlchemyObjectType. + + Parameters + - obj_type : SQLAlchemyObjectType + The object type for which the sort Argument shall be generated. + - enum_name : str, optional, default None + Name to use for the sort Enum. + If not provided, it will be set to the object type name + 'SortEnum' + - only_fields : sequence, optional, default None + If this is set, only fields from this sequence will be considered. + - only_indexed : bool, optional, default False + If this is set, only indexed columns will be considered. + - get_symbol_name : function, optional, default None + Function which takes the column name and a boolean indicating + if the sort direction is ascending, and returns the symbol name + for the current column and sort direction. If no such function + is passed, a default function will be used that creates the symbols + 'foo_asc' and 'foo_desc' for a column with the name 'foo'. + - has_default : bool, optional, default True + If this is set to False, no sorting will happen when this argument is not + passed. Otherwise results will be sorted by the primary key(s) of the model. + + Returns + - Argument + A Graphene Argument that accepts a list of sorting directions for the model.
+ """ + enum = sort_enum_for_object_type( + obj_type, + enum_name, + only_fields=only_fields, + only_indexed=only_indexed, + get_symbol_name=get_symbol_name, + ) + if not has_default: + enum.default = None + + return Argument(List(enum), default_value=enum.default) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 4a46b74..780fcbf 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,16 +1,17 @@ -import logging +import warnings from functools import partial +import six from promise import Promise, is_thenable from sqlalchemy.orm.query import Query +from graphene import NonNull from graphene.relay import Connection, ConnectionField from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice -from .utils import get_query, sort_argument_for_model - -log = logging.getLogger() +from .batching import get_batch_resolver +from .utils import get_query class UnsortedSQLAlchemyConnectionField(ConnectionField): @@ -19,29 +20,30 @@ from .types import SQLAlchemyObjectType _type = super(ConnectionField, self).type - if issubclass(_type, Connection): + nullable_type = get_nullable_type(_type) + if issubclass(nullable_type, Connection): return _type - assert issubclass(_type, SQLAlchemyObjectType), ( + assert issubclass(nullable_type, SQLAlchemyObjectType), ( "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" - ).format(_type.__name__) - assert _type._meta.connection, "The type {} doesn't have a connection".format( - _type.__name__ + ).format(nullable_type.__name__) + assert ( + nullable_type.connection + ), "The type {} doesn't have a connection".format( + nullable_type.__name__ ) - return _type._meta.connection + assert _type == nullable_type, ( + "Passing a SQLAlchemyObjectType instance is deprecated. 
" + "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" + ) + return nullable_type.connection @property def model(self): - return self.type._meta.node._meta.model + return get_nullable_type(self.type)._meta.node._meta.model @classmethod - def get_query(cls, model, info, sort=None, **args): - query = get_query(model, info.context) - if sort is not None: - if isinstance(sort, str): - query = query.order_by(sort.value) - else: - query = query.order_by(*(col.value for col in sort)) - return query + def get_query(cls, model, info, **args): + return get_query(model, info.context) @classmethod def resolve_connection(cls, connection_type, model, info, args, resolved): @@ -76,59 +78,106 @@ return on_resolve(resolved) def get_resolver(self, parent_resolver): - return partial(self.connection_resolver, parent_resolver, self.type, self.model) + return partial( + self.connection_resolver, + parent_resolver, + get_nullable_type(self.type), + self.model, + ) +# TODO Rename this to SortableSQLAlchemyConnectionField class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): def __init__(self, type, *args, **kwargs): - if "sort" not in kwargs and issubclass(type, Connection): + nullable_type = get_nullable_type(type) + if "sort" not in kwargs and issubclass(nullable_type, Connection): # Let super class raise if type is not a Connection try: - model = type.Edge.node._type._meta.model - kwargs.setdefault("sort", sort_argument_for_model(model)) - except Exception: - raise Exception( + kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) + except (AttributeError, TypeError): + raise TypeError( 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' " to None to disabling the creation of the sort query argument".format( - type.__name__ + nullable_type.__name__ ) ) elif "sort" in kwargs and kwargs["sort"] is None: del kwargs["sort"] super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + @classmethod + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if isinstance(sort, six.string_types): + query = query.order_by(sort.value) + else: + query = query.order_by(*(col.value for col in sort)) + return query -def default_connection_field_factory(relationship, registry): + +class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): + """ + This is currently experimental. + The API and behavior may change in future versions. + Use at your own risk. + """ + + def get_resolver(self, parent_resolver): + return partial( + self.connection_resolver, + self.resolver, + get_nullable_type(self.type), + self.model, + ) + + @classmethod + def from_relationship(cls, relationship, registry, **field_kwargs): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) + + +def default_connection_field_factory(relationship, registry, **field_kwargs): model = relationship.mapper.entity model_type = registry.get_type_for_model(model) - return createConnectionField(model_type) + return __connectionFactory(model_type, **field_kwargs) # TODO Remove in next major version __connectionFactory = UnsortedSQLAlchemyConnectionField -def createConnectionField(_type): - log.warn( +def createConnectionField(_type, **field_kwargs): + warnings.warn( 'createConnectionField is deprecated and will be removed in the next ' - 'major version. 
Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + DeprecationWarning, ) - return __connectionFactory(_type) + return __connectionFactory(_type, **field_kwargs) def registerConnectionFieldFactory(factoryMethod): - log.warn( + warnings.warn( 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + DeprecationWarning, ) global __connectionFactory __connectionFactory = factoryMethod def unregisterConnectionFieldFactory(): - log.warn( + warnings.warn( 'registerConnectionFieldFactory is deprecated and will be removed in the next ' - 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', + DeprecationWarning, ) global __connectionFactory __connectionFactory = UnsortedSQLAlchemyConnectionField + + +def get_nullable_type(_type): + if isinstance(_type, NonNull): + return _type.of_type + return _type diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 460053f..c20bc2c 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,31 +1,91 @@ +from collections import defaultdict + +import six +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + + class Registry(object): def __init__(self): self._registry = {} self._registry_models = {} + self._registry_orm_fields = defaultdict(dict) self._registry_composites = {} + self._registry_enums = {} + self._registry_sort_enums = {} - def register(self, cls): + def register(self, obj_type): from .types import SQLAlchemyObjectType - assert issubclass(cls, SQLAlchemyObjectType), ( - "Only classes of type SQLAlchemyObjectType can be registered, " - 'received "{}"' - ).format(cls.__name__) - assert cls._meta.registry == self, "Registry for a Model have to match." + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + assert obj_type._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # 'SQLAlchemy model "{}" already associated with ' # 'another type "{}".' 
# ).format(cls._meta.model, self._registry[cls._meta.model]) - self._registry[cls._meta.model] = cls + self._registry[obj_type._meta.model] = obj_type def get_type_for_model(self, model): return self._registry.get(model) + + def register_orm_field(self, obj_type, field_name, orm_field): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + if not field_name or not isinstance(field_name, six.string_types): + raise TypeError("Expected a field name, but got: {!r}".format(field_name)) + self._registry_orm_fields[obj_type][field_name] = orm_field + + def get_orm_field_for_graphene_field(self, obj_type, field_name): + return self._registry_orm_fields.get(obj_type, {}).get(field_name) def register_composite_converter(self, composite, converter): self._registry_composites[composite] = converter def get_converter_for_composite(self, composite): return self._registry_composites.get(composite) + + def register_enum(self, sa_enum, graphene_enum): + if not isinstance(sa_enum, SQLAlchemyEnumType): + raise TypeError( + "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) + ) + if not isinstance(graphene_enum, type(Enum)): + raise TypeError( + "Expected Graphene Enum, but got: {!r}".format(graphene_enum) + ) + + self._registry_enums[sa_enum] = graphene_enum + + def get_graphene_enum_for_sa_enum(self, sa_enum): + return self._registry_enums.get(sa_enum) + + def register_sort_enum(self, obj_type, sort_enum): + from .types import SQLAlchemyObjectType + + if not isinstance(obj_type, type) or not issubclass( + obj_type, SQLAlchemyObjectType + ): + raise TypeError( + "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) + ) + if not isinstance(sort_enum, type(Enum)): + raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) + self._registry_sort_enums[obj_type] = sort_enum + + def get_sort_enum_for_object_type(self, obj_type): + return self._registry_sort_enums.get(obj_type) registry = None diff --git a/graphene_sqlalchemy/resolvers.py b/graphene_sqlalchemy/resolvers.py new file mode 100644 index 0000000..83a6e35 --- /dev/null +++ b/graphene_sqlalchemy/resolvers.py @@ -0,0 +1,26 @@ +from graphene.utils.get_unbound_function import get_unbound_function + + +def get_custom_resolver(obj_type, orm_field_name): + """ + Since `graphene` will call `resolve_` on a field only if it + does not have a `resolver`, we need to re-implement that logic here so + users are able to override the default resolvers that we provide. + """ + resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) + if resolver: + return get_unbound_function(resolver) + + return None + + +def get_attr_resolver(obj_type, model_attr): + """ + In order to support field renaming via `ORMField.model_attr`, + we need to define resolver functions for each field. 
+ + :param SQLAlchemyObjectType obj_type: + :param str model_attr: the name of the SQLAlchemy attribute + :rtype: Callable + """ + return lambda root, _info: getattr(root, model_attr, None) diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py new file mode 100644 index 0000000..9851505 --- /dev/null +++ b/graphene_sqlalchemy/tests/conftest.py @@ -0,0 +1,39 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +import graphene + +from ..converter import convert_sqlalchemy_composite +from ..registry import reset_global_registry +from .models import Base, CompositeFullName + +test_db_url = 'sqlite://' # use in-memory database for tests + + +@pytest.fixture(autouse=True) +def reset_registry(): + reset_global_registry() + + # Prevent tests that implicitly depend on Reporter from raising + # Tests that explicitly depend on this behavior should re-register a converter + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.Field(graphene.Int) + + +@pytest.yield_fixture(scope="function") +def session_factory(): + engine = create_engine(test_db_url) + Base.metadata.create_all(engine) + + yield sessionmaker(bind=engine) + + # SQLite in-memory db is deleted when its connection is closed. + # https://www.sqlite.org/inmemorydb.html + engine.dispose() + + +@pytest.fixture(scope="function") +def session(session_factory): + return session_factory() diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3ba23a8..88e992b 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -2,12 +2,16 @@ import enum -from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table +from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, + func, select) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import mapper, relationship +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import column_property, composite, mapper, relationship + +PetKind = Enum("cat", "dog", name="pet_kind") -class Hairkind(enum.Enum): +class HairKind(enum.Enum): LONG = 'long' SHORT = 'short' @@ -32,26 +36,44 @@ __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) - pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) - hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) + pet_kind = Column(PetKind, nullable=False) + hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) reporter_id = Column(Integer(), ForeignKey("reporters.id")) + + +class CompositeFullName(object): + def __init__(self, first_name, last_name): + self.first_name = first_name + self.last_name = last_name + + def __composite_values__(self): + return self.first_name, self.last_name + + def __repr__(self): + return "{} {}".format(self.first_name, self.last_name) class Reporter(Base): __tablename__ = "reporters" + id = Column(Integer(), primary_key=True) - first_name = Column(String(30)) - last_name = Column(String(30)) - email = Column(String()) - pets = relationship("Pet", secondary=association_table, backref="reporters") + first_name = Column(String(30), doc="First name") + last_name = Column(String(30), doc="Last name") + email = Column(String(), doc="Email") + favorite_pet_kind = Column(PetKind) + pets = relationship("Pet", secondary=association_table, backref="reporters", 
order_by="Pet.id") articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) - # total = column_property( - # select([ - # func.cast(func.count(PersonInfo.id), Float) - # ]) - # ) + @hybrid_property + def hybrid_prop(self): + return self.first_name + + column_prop = column_property( + select([func.cast(func.count(id), Integer)]), doc="Column property" + ) + + composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") class Article(Base): diff --git a/graphene_sqlalchemy/tests/test_batching.py b/graphene_sqlalchemy/tests/test_batching.py new file mode 100644 index 0000000..fc646a3 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_batching.py @@ -0,0 +1,698 @@ +import contextlib +import logging + +import pytest + +import graphene +from graphene import relay + +from ..fields import (BatchSQLAlchemyConnectionField, + default_connection_field_factory) +from ..types import ORMField, SQLAlchemyObjectType +from .models import Article, HairKind, Pet, Reporter +from .utils import is_sqlalchemy_version_less_than, to_std_dicts + + +class MockLoggingHandler(logging.Handler): + """Intercept and store log messages in a list.""" + def __init__(self, *args, **kwargs): + self.messages = [] + logging.Handler.__init__(self, *args, **kwargs) + + def emit(self, record): + self.messages.append(record.getMessage()) + + +@contextlib.contextmanager +def mock_sqlalchemy_logging_handler(): + logging.basicConfig() + sql_logger = logging.getLogger('sqlalchemy.engine') + previous_level = sql_logger.level + + sql_logger.setLevel(logging.INFO) + mock_logging_handler = MockLoggingHandler() + mock_logging_handler.setLevel(logging.INFO) + sql_logger.addHandler(mock_logging_handler) + + yield mock_logging_handler + + sql_logger.setLevel(previous_level) + + +def get_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + batching = True + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + batching = True + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, info): + return info.context.get('session').query(Article).all() + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + return graphene.Schema(query=Query) + + +if is_sqlalchemy_version_less_than('1.2'): + pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) + + +def test_many_to_one(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + articles { + headline + reporter { + firstName + } + } + } + """, 
context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date, ' + 'articles.reporter_id AS articles_reporter_id \n' + 'FROM articles', + '()', + + 'SELECT reporters.id AS reporters_id, ' + '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters \n' + 'WHERE reporters.id IN (?, ?)', + '(1, 2)', + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "articles": [ + { + "headline": "Article_1", + "reporter": { + "firstName": "Reporter_1", + }, + }, + { + "headline": "Article_2", + "reporter": { + "firstName": "Reporter_2", + }, + }, + ], + } + + +def test_one_to_one(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?)', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "favoriteArticle": { + "headline": 
"Article_1", + }, + }, + { + "firstName": "Reporter_2", + "favoriteArticle": { + "headline": "Article_2", + }, + }, + ], + } + + +def test_one_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline='Article_3') + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline='Article_4') + article_4.reporter = reporter_2 + session.add(article_4) + + session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT articles.reporter_id AS articles_reporter_id, ' + 'articles.id AS articles_id, ' + 'articles.headline AS articles_headline, ' + 'articles.pub_date AS articles_pub_date \n' + 'FROM articles \n' + 'WHERE articles.reporter_id IN (?, ?)', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_1", + }, + }, + { + "node": { + "headline": "Article_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "articles": { + "edges": [ + { + "node": { + "headline": "Article_3", + }, + }, + { + "node": { + "headline": "Article_4", + }, + }, + ], + }, + }, + ], + } + + +def test_many_to_many(session_factory): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + 
session.commit() + session.close() + + schema = get_schema() + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + result = schema.execute(""" + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + assert len(messages) == 5 + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] + assert len(sql_statements) == 1 + return + + assert messages == [ + 'BEGIN (implicit)', + + 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' + 'reporters.id AS reporters_id, ' + 'reporters.first_name AS reporters_first_name, ' + 'reporters.last_name AS reporters_last_name, ' + 'reporters.email AS reporters_email, ' + 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' + 'FROM reporters', + '()', + + 'SELECT reporters_1.id AS reporters_1_id, ' + 'pets.id AS pets_id, ' + 'pets.name AS pets_name, ' + 'pets.pet_kind AS pets_pet_kind, ' + 'pets.hair_kind AS pets_hair_kind, ' + 'pets.reporter_id AS pets_reporter_id \n' + 'FROM reporters AS reporters_1 ' + 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' + 'JOIN pets ON pets.id = association_1.pet_id \n' + 'WHERE reporters_1.id IN (?, ?) ' + 'ORDER BY pets.id', + '(1, 2)' + ] + + assert not result.errors + result = to_std_dicts(result.data) + assert result == { + "reporters": [ + { + "firstName": "Reporter_1", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_1", + }, + }, + { + "node": { + "name": "Pet_2", + }, + }, + ], + }, + }, + { + "firstName": "Reporter_2", + "pets": { + "edges": [ + { + "node": { + "name": "Pet_3", + }, + }, + { + "node": { + "name": "Pet_4", + }, + }, + ], + }, + }, + ], + } + + +def test_disable_batching_via_ormfield(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + + favorite_article = ORMField(batching=False) + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + # Test one-to-one and many-to-one relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + favoriteArticle { + headline + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 
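+    # two SELECTs are expected here: batching is disabled on favorite_article, so each of the two reporters loads it with its own query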
+ + # Test one-to-many and many-to-many relationships + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 + + +def test_connection_factory_field_overrides_batching_is_false(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = False + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + articles = ORMField(batching=False) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = sqlalchemy_logging_handler.messages + + if is_sqlalchemy_version_less_than('1.3'): + # The batched SQL statement generated is different in 1.2.x + # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` + # See https://git.io/JewQu + select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] + else: + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 1 + + +def test_connection_factory_field_overrides_batching_is_true(session_factory): + session = session_factory() + reporter_1 = Reporter(first_name='Reporter_1') + session.add(reporter_1) + reporter_2 = Reporter(first_name='Reporter_2') + session.add(reporter_2) + session.commit() + session.close() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + batching = True + connection_field_factory = default_connection_field_factory + + articles = ORMField(batching=True) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + + class Query(graphene.ObjectType): + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + schema = graphene.Schema(query=Query) + + with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: + # Starts new session to fully reset the engine / connection logging level + session = session_factory() + schema.execute(""" + query { + reporters { + articles { + edges { + node { + headline + } + } + } + } + } + """, context_value={"session": session}) + messages = 
sqlalchemy_logging_handler.messages + + select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] + assert len(select_statements) == 2 diff --git a/graphene_sqlalchemy/tests/test_benchmark.py b/graphene_sqlalchemy/tests/test_benchmark.py new file mode 100644 index 0000000..1e5ee4f --- /dev/null +++ b/graphene_sqlalchemy/tests/test_benchmark.py @@ -0,0 +1,226 @@ +import pytest +from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend + +import graphene +from graphene import relay + +from ..fields import BatchSQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from .models import Article, HairKind, Pet, Reporter +from .utils import is_sqlalchemy_version_less_than + +if is_sqlalchemy_version_less_than('1.2'): + pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) + + +def get_schema(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (relay.Node,) + connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship + + class Query(graphene.ObjectType): + articles = graphene.Field(graphene.List(ArticleType)) + reporters = graphene.Field(graphene.List(ReporterType)) + + def resolve_articles(self, info): + return info.context.get('session').query(Article).all() + + def resolve_reporters(self, info): + return info.context.get('session').query(Reporter).all() + + return graphene.Schema(query=Query) + + +def benchmark_query(session_factory, benchmark, query): + schema = get_schema() + cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) + cached_backend.document_from_string(schema, query) # Prime cache + + @benchmark + def execute_query(): + result = schema.execute( + query, + context_value={"session": session_factory()}, + backend=cached_backend, + ) + assert not result.errors + + +def test_one_to_one(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + favoriteArticle { + headline + } + } + } + """) + + +def test_many_to_one(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_2 + session.add(article_2) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + articles { + headline + reporter { + firstName + } + } + } + """) + + +def test_one_to_many(session_factory, benchmark): + 
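+    # two reporters with two articles each; the query below pages each articles connection with first: 2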
session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + article_1 = Article(headline='Article_1') + article_1.reporter = reporter_1 + session.add(article_1) + + article_2 = Article(headline='Article_2') + article_2.reporter = reporter_1 + session.add(article_2) + + article_3 = Article(headline='Article_3') + article_3.reporter = reporter_2 + session.add(article_3) + + article_4 = Article(headline='Article_4') + article_4.reporter = reporter_2 + session.add(article_4) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + """) + + +def test_many_to_many(session_factory, benchmark): + session = session_factory() + + reporter_1 = Reporter( + first_name='Reporter_1', + ) + session.add(reporter_1) + reporter_2 = Reporter( + first_name='Reporter_2', + ) + session.add(reporter_2) + + pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_1) + + pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_2) + + reporter_1.pets.append(pet_1) + reporter_1.pets.append(pet_2) + + pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_3) + + pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) + session.add(pet_4) + + reporter_2.pets.append(pet_3) + reporter_2.pets.append(pet_4) + + session.commit() + session.close() + + benchmark_query(session_factory, benchmark, """ + query { + reporters { + firstName + pets(first: 2) { + edges { + node { + name + } + } + } + } + } + """) diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 5cc16e7..f0fc180 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,11 +1,11 @@ import enum -from py.test import raises -from sqlalchemy import Column, Table, case, func, select, types +import pytest +from sqlalchemy import Column, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.inspection import inspect from sqlalchemy.orm import column_property, composite -from sqlalchemy.sql.elements import Label from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType import graphene @@ -18,170 +18,185 @@ convert_sqlalchemy_relationship) from ..fields import (UnsortedSQLAlchemyConnectionField, default_connection_field_factory) -from ..registry import Registry +from ..registry import Registry, get_global_registry from ..types import SQLAlchemyObjectType -from .models import Article, Pet, Reporter - - -def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): - column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) - graphene_type = convert_sqlalchemy_column(column) - assert isinstance(graphene_type, graphene_field) - field = ( - graphene_type - if isinstance(graphene_type, graphene.Field) - else graphene_type.Field() - ) - assert field.description == "Custom Help Text" - return field - - -def assert_composite_conversion( - composite_class, composite_columns, graphene_field, registry, **kwargs -): - composite_column = composite( - composite_class, *composite_columns, doc="Custom Help Text", **kwargs - ) - graphene_type = convert_sqlalchemy_composite(composite_column, registry) 
- assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() - # SQLAlchemy currently does not persist the doc onto the column, even though - # the documentation says it does.... - # assert field.description == 'Custom Help Text' - return field +from .models import Article, CompositeFullName, Pet, Reporter + + +def mock_resolver(): + pass + + +def get_field(sqlalchemy_type, **column_kwargs): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) + + column_prop = inspect(Model).column_attrs['column'] + return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) + + +def get_field_from_column(column_): + class Model(declarative_base()): + __tablename__ = 'model' + id_ = Column(types.Integer, primary_key=True) + column = column_ + + column_prop = inspect(Model).column_attrs['column'] + return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) def test_should_unknown_sqlalchemy_field_raise_exception(): - with raises(Exception) as excinfo: - convert_sqlalchemy_column(None) - assert "Don't know how to convert the SQLAlchemy field" in str(excinfo.value) + re_err = "Don't know how to convert the SQLAlchemy field" + with pytest.raises(Exception, match=re_err): + # support legacy Binary type and subsequent LargeBinary + get_field(getattr(types, 'LargeBinary', types.Binary)()) def test_should_date_convert_string(): - assert_column_conversion(types.Date(), graphene.String) - - -def test_should_datetime_convert_string(): - assert_column_conversion(types.DateTime(), DateTime) + assert get_field(types.Date()).type == graphene.String + + +def test_should_datetime_convert_datetime(): + assert get_field(types.DateTime()).type == DateTime def test_should_time_convert_string(): - assert_column_conversion(types.Time(), graphene.String) + assert get_field(types.Time()).type == graphene.String def test_should_string_convert_string(): - assert_column_conversion(types.String(), graphene.String) + assert get_field(types.String()).type == graphene.String def test_should_text_convert_string(): - assert_column_conversion(types.Text(), graphene.String) + assert get_field(types.Text()).type == graphene.String def test_should_unicode_convert_string(): - assert_column_conversion(types.Unicode(), graphene.String) + assert get_field(types.Unicode()).type == graphene.String def test_should_unicodetext_convert_string(): - assert_column_conversion(types.UnicodeText(), graphene.String) + assert get_field(types.UnicodeText()).type == graphene.String def test_should_enum_convert_enum(): - field = assert_column_conversion( - types.Enum(enum.Enum("one", "two")), graphene.Field - ) + field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) field_type = field.type() assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") - field = assert_column_conversion( - types.Enum("one", "two", name="two_numbers"), graphene.Field - ) + assert field_type._meta.name == "TwoNumbers" + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + field = get_field(types.Enum("one", "two", name="two_numbers")) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert field_type._meta.name == "TwoNumbers" + 
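+    # enum member names are exposed in upper case (ONE/TWO) rather than as the original "one"/"two"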
assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") + + +def test_should_not_enum_convert_enum_without_name(): + field = get_field(types.Enum("one", "two")) + re_err = r"No type name specified for Enum\('one', 'two'\)" + with pytest.raises(TypeError, match=re_err): + field.type() def test_should_small_integer_convert_int(): - assert_column_conversion(types.SmallInteger(), graphene.Int) + assert get_field(types.SmallInteger()).type == graphene.Int def test_should_big_integer_convert_int(): - assert_column_conversion(types.BigInteger(), graphene.Float) + assert get_field(types.BigInteger()).type == graphene.Float def test_should_integer_convert_int(): - assert_column_conversion(types.Integer(), graphene.Int) - - -def test_should_integer_convert_id(): - assert_column_conversion(types.Integer(), graphene.ID, primary_key=True) + assert get_field(types.Integer()).type == graphene.Int + + +def test_should_primary_integer_convert_id(): + assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID) def test_should_boolean_convert_boolean(): - assert_column_conversion(types.Boolean(), graphene.Boolean) + assert get_field(types.Boolean()).type == graphene.Boolean def test_should_float_convert_float(): - assert_column_conversion(types.Float(), graphene.Float) + assert get_field(types.Float()).type == graphene.Float def test_should_numeric_convert_float(): - assert_column_conversion(types.Numeric(), graphene.Float) - - -def test_should_label_convert_string(): - label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.String) - - -def test_should_label_convert_int(): - label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) - graphene_type = convert_sqlalchemy_column(label) - assert isinstance(graphene_type, graphene.Int) + assert get_field(types.Numeric()).type == graphene.Float def test_should_choice_convert_enum(): - TYPES = [(u"es", u"Spanish"), (u"en", u"English")] - column = Column(ChoiceType(TYPES), doc="Language", name="language") - Base = declarative_base() - - Table("translatedmodel", Base.metadata, column) - graphene_type = convert_sqlalchemy_column(column) + field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) + graphene_type = field.type assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" - assert graphene_type._meta.description == "Language" + assert graphene_type._meta.name == "MODEL_COLUMN" assert graphene_type._meta.enum.__members__["es"].value == "Spanish" assert graphene_type._meta.enum.__members__["en"].value == "English" +def test_should_enum_choice_convert_enum(): + class TestEnum(enum.Enum): + es = u"Spanish" + en = u"English" + + field = get_field(ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == "MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" + + +def test_should_intenum_choice_convert_enum(): + class TestEnum(enum.IntEnum): + one = 1 + two = 2 + + field = get_field(ChoiceType(TestEnum, impl=types.String())) + graphene_type = field.type + assert issubclass(graphene_type, graphene.Enum) + assert graphene_type._meta.name == 
"MODEL_COLUMN" + assert graphene_type._meta.enum.__members__["one"].value == 1 + assert graphene_type._meta.enum.__members__["two"].value == 2 + + def test_should_columproperty_convert(): - - Base = declarative_base() - - class Test(Base): - __tablename__ = "test" - id = Column(types.Integer, primary_key=True) - column = column_property( - select([func.sum(func.cast(id, types.Integer))]).where(id == 1) - ) - - graphene_type = convert_sqlalchemy_column(Test.column) - assert not graphene_type.kwargs["required"] + field = get_field_from_column(column_property( + select([func.sum(func.cast(id, types.Integer))]).where(id == 1) + )) + + assert field.type == graphene.Int def test_should_scalar_list_convert_list(): - assert_column_conversion(ScalarListType(), graphene.List) + field = get_field(ScalarListType()) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.String def test_should_jsontype_convert_jsonstring(): - assert_column_conversion(JSONType(), JSONString) + assert get_field(JSONType()).type == JSONString def test_should_manytomany_convert_connectionorlist(): - registry = Registry() - dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, registry, default_connection_field_factory + class A(SQLAlchemyObjectType): + class Meta: + model = Article + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -193,7 +208,7 @@ model = Pet dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -209,16 +224,19 @@ interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Reporter.pets.property, A._meta.registry, default_connection_field_factory + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): - registry = Registry() - dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, registry, default_connection_field_factory + class A(SQLAlchemyObjectType): + class Meta: + model = Article + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -230,7 +248,7 @@ model = Reporter dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry, default_connection_field_factory + Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -245,7 +263,7 @@ interfaces = (Node,) dynamic_field = convert_sqlalchemy_relationship( - Article.reporter.property, A._meta.registry, default_connection_field_factory + Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -260,7 +278,7 @@ interfaces = (Node,) dynamic_field = 
convert_sqlalchemy_relationship( - Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory + Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() @@ -269,80 +287,85 @@ def test_should_postgresql_uuid_convert(): - assert_column_conversion(postgresql.UUID(), graphene.String) + assert get_field(postgresql.UUID()).type == graphene.String def test_should_postgresql_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field - ) + field = get_field(postgresql.ENUM("one", "two", name="two_numbers")) field_type = field.type() - assert field_type.__class__.__name__ == "two_numbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert field_type._meta.name == "TwoNumbers" + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") def test_should_postgresql_py_enum_convert(): - field = assert_column_conversion( - postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field - ) + field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers")) field_type = field.type() - assert field_type.__class__.__name__ == "TwoNumbers" + assert field_type._meta.name == "TwoNumbers" assert isinstance(field_type, graphene.Enum) - assert hasattr(field_type, "two") + assert hasattr(field_type, "ONE") + assert not hasattr(field_type, "one") + assert hasattr(field_type, "TWO") + assert not hasattr(field_type, "two") def test_should_postgresql_array_convert(): - assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) + field = get_field(postgresql.ARRAY(types.Integer)) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int + + +def test_should_array_convert(): + field = get_field(types.ARRAY(types.Integer)) + assert isinstance(field.type, graphene.List) + assert field.type.of_type == graphene.Int def test_should_postgresql_json_convert(): - assert_column_conversion(postgresql.JSON(), JSONString) + assert get_field(postgresql.JSON()).type == graphene.JSONString def test_should_postgresql_jsonb_convert(): - assert_column_conversion(postgresql.JSONB(), JSONString) + assert get_field(postgresql.JSONB()).type == graphene.JSONString def test_should_postgresql_hstore_convert(): - assert_column_conversion(postgresql.HSTORE(), JSONString) + assert get_field(postgresql.HSTORE()).type == graphene.JSONString def test_should_composite_convert(): - class CompositeClass(object): + registry = Registry() + + class CompositeClass: def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 - registry = Registry() - @convert_sqlalchemy_composite.register(CompositeClass, registry) def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, + field = convert_sqlalchemy_composite( + composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"), registry, - ) + mock_resolver, + ) + assert isinstance(field, graphene.String) def test_should_unknown_sqlalchemy_composite_raise_exception(): - registry = Registry() - - with raises(Exception) as excinfo: - - class 
CompositeClass(object): - def __init__(self, col1, col2): - self.col1 = col1 - self.col2 = col2 - - assert_composite_conversion( - CompositeClass, - (Column(types.Unicode(50)), Column(types.Unicode(50))), - graphene.String, - registry, + class CompositeClass: + def __init__(self, col1, col2): + self.col1 = col1 + self.col2 = col2 + + re_err = "Don't know how to convert the composite field" + with pytest.raises(Exception, match=re_err): + convert_sqlalchemy_composite( + composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))), + Registry(), + mock_resolver, ) - - assert "Don't know how to convert the composite field" in str(excinfo.value) diff --git a/graphene_sqlalchemy/tests/test_enums.py b/graphene_sqlalchemy/tests/test_enums.py new file mode 100644 index 0000000..ca37696 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_enums.py @@ -0,0 +1,122 @@ +from enum import Enum as PyEnum + +import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnumType + +from graphene import Enum + +from ..enums import _convert_sa_to_graphene_enum, enum_for_field +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet + + +def test_convert_sa_to_graphene_enum_bad_type(): + re_err = "Expected sqlalchemy.types.Enum, but got: 'foo'" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum("foo") + + +def test_convert_sa_to_graphene_enum_based_on_py_enum(): + class Color(PyEnum): + RED = 1 + GREEN = 2 + BLUE = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is Color + + +def test_convert_sa_to_graphene_enum_based_on_py_enum_with_bad_names(): + class Color(PyEnum): + red = 1 + green = 2 + blue = 3 + + sa_enum = SQLAlchemyEnumType(Color) + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "Color" + assert graphene_enum._meta.enum is not Color + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue", name="color_values") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "ColorValues" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") + assert isinstance(graphene_enum, type(Enum)) + assert graphene_enum._meta.name == "FallbackName" + assert [ + (key, value.value) + for key, value in graphene_enum._meta.enum.__members__.items() + ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] + + +def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): + sa_enum = SQLAlchemyEnumType("red", "green", "blue") + re_err = r"No type name specified for Enum\('red', 'green', 'blue'\)" + with pytest.raises(TypeError, match=re_err): + _convert_sa_to_graphene_enum(sa_enum) + + +def test_enum_for_field(): + class 
PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + enum = enum_for_field(PetType, 'pet_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "PetKind" + assert [ + (key, value.value) + for key, value in enum._meta.enum.__members__.items() + ] == [("CAT", 'cat'), ("DOG", 'dog')] + enum2 = enum_for_field(PetType, 'pet_kind') + assert enum2 is enum + enum2 = PetType.enum_for_field('pet_kind') + assert enum2 is enum + + enum = enum_for_field(PetType, 'hair_kind') + assert isinstance(enum, type(Enum)) + assert enum._meta.name == "HairKind" + assert enum._meta.enum is HairKind + enum2 = PetType.enum_for_field('hair_kind') + assert enum2 is enum + + re_err = r"Cannot get PetType\.other_kind" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'other_kind') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('other_kind') + + re_err = r"PetType\.name does not map to enum column" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, 'name') + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field('name') + + re_err = r"Expected a field name, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(PetType, None) + with pytest.raises(TypeError, match=re_err): + PetType.enum_for_field(None) + + re_err = "Expected SQLAlchemyObjectType, but got: None" + with pytest.raises(TypeError, match=re_err): + enum_for_field(None, 'other_kind') diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py index ff616b3..357055e 100644 --- a/graphene_sqlalchemy/tests/test_fields.py +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -1,40 +1,85 @@ import pytest +from promise import Promise -from graphene.relay import Connection +from graphene import NonNull, ObjectType +from graphene.relay import Connection, Node -from ..fields import SQLAlchemyConnectionField +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField) from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model -from .models import Editor +from .models import Editor as EditorModel from .models import Pet as PetModel class Pet(SQLAlchemyObjectType): class Meta: model = PetModel + interfaces = (Node,) -class PetConn(Connection): +class Editor(SQLAlchemyObjectType): class Meta: - node = Pet + model = EditorModel + +## +# SQLAlchemyConnectionField +## + + +def test_nonnull_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(NonNull(Pet.connection)) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_required_sqlalachemy_connection(): + field = SQLAlchemyConnectionField(Pet.connection, required=True) + assert isinstance(field.type, NonNull) + assert issubclass(field.type.of_type, Connection) + assert field.type.of_type._meta.node is Pet + + +def test_promise_connection_resolver(): + def resolver(_obj, _info): + return Promise.resolve([]) + + result = UnsortedSQLAlchemyConnectionField.connection_resolver( + resolver, Pet.connection, Pet, None, None + ) + assert isinstance(result, Promise) + + +def test_type_assert_sqlalchemy_object_type(): + with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): + SQLAlchemyConnectionField(ObjectType).type + + +def test_type_assert_object_has_connection(): + with pytest.raises(AssertionError, match="doesn't have a connection"): + SQLAlchemyConnectionField(Editor).type + +## +# 
UnsortedSQLAlchemyConnectionField +## def test_sort_added_by_default(): - arg = SQLAlchemyConnectionField(PetConn) - assert "sort" in arg.args - assert arg.args["sort"] == sort_argument_for_model(PetModel) + field = SQLAlchemyConnectionField(Pet.connection) + assert "sort" in field.args + assert field.args["sort"] == Pet.sort_argument() def test_sort_can_be_removed(): - arg = SQLAlchemyConnectionField(PetConn, sort=None) - assert "sort" not in arg.args + field = SQLAlchemyConnectionField(Pet.connection, sort=None) + assert "sort" not in field.args def test_custom_sort(): - arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor)) - assert arg.args["sort"] == sort_argument_for_model(Editor) + field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument()) + assert field.args["sort"] == Editor.sort_argument() -def test_init_raises(): - with pytest.raises(Exception, match="Cannot create sort"): +def test_sort_init_raises(): + with pytest.raises(TypeError, match="Cannot create sort"): SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 146c54e..3914081 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,55 +1,40 @@ -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker - import graphene -from graphene.relay import Connection, Node - +from graphene.relay import Node + +from ..converter import convert_sqlalchemy_composite from ..fields import SQLAlchemyConnectionField -from ..registry import reset_global_registry -from ..types import SQLAlchemyObjectType -from ..utils import sort_argument_for_model, sort_enum_for_model -from .models import Article, Base, Editor, Hairkind, Pet, Reporter - -db = create_engine("sqlite:///test_sqlalchemy.sqlite3") - - -@pytest.yield_fixture(scope="function") -def session(): - reset_global_registry() - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) - - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) - - yield session - - # Finalize test here - transaction.rollback() - connection.close() - session.remove() - - -def setup_fixtures(session): - pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) +from ..types import ORMField, SQLAlchemyObjectType +from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter +from .utils import to_std_dicts + + +def add_test_data(session): + reporter = Reporter( + first_name='John', last_name='Doe', favorite_pet_kind='cat') + session.add(reporter) + pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) session.add(pet) - reporter = Reporter(first_name="ABA", last_name="X") - session.add(reporter) - reporter2 = Reporter(first_name="ABO", last_name="Y") - session.add(reporter2) - article = Article(headline="Hi!") + pet.reporters.append(reporter) + article = Article(headline='Hi!') article.reporter = reporter session.add(article) - editor = Editor(name="John") + reporter = Reporter( + first_name='Jane', last_name='Roe', favorite_pet_kind='dog') + session.add(reporter) + pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) + pet.reporters.append(reporter) + session.add(pet) + editor = Editor(name="Jack") session.add(editor) session.commit() -def test_should_query_well(session): - setup_fixtures(session) +def 
test_query_fields(session): + add_test_data(session) + + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() class ReporterType(SQLAlchemyObjectType): class Meta: @@ -59,18 +44,19 @@ reporter = graphene.Field(ReporterType) reporters = graphene.List(ReporterType) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_reporters(self, *args, **kwargs): + def resolve_reporters(self, _info): return session.query(Reporter) query = """ - query ReporterQuery { + query { reporter { - firstName, - lastName, - email + firstName + columnProp + hybridProp + compositeProp } reporters { firstName @@ -78,117 +64,23 @@ } """ expected = { - "reporter": {"firstName": "ABA", "lastName": "X", "email": None}, - "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}], + "reporter": { + "firstName": "John", + "hybridProp": "John", + "columnProp": 2, + "compositeProp": "John Doe", + }, + "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], } schema = graphene.Schema(query=Query) result = schema.execute(query) assert not result.errors - assert result.data == expected - - -def test_should_query_enums(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType) - - def resolve_pet(self, *args, **kwargs): - return session.query(Pet).first() - - query = """ - query PetQuery { - pet { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query) - assert not result.errors - assert result.data == expected, result.data - - -def test_enum_parameter(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) - - def resolve_pet(self, info, kind=None, *args, **kwargs): - query = session.query(Pet) - if kind: - query = query.filter(Pet.pet_kind == kind) - return query.first() - - query = """ - query PetQuery($kind: pet_kind) { - pet(kind: $kind) { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "cat"}) - assert not result.errors - assert result.data == {"pet": None} - result = schema.execute(query, variables={"kind": "dog"}) - assert not result.errors - assert result.data == expected, result.data - - -def test_py_enum_parameter(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) - - def resolve_pet(self, info, kind=None, *args, **kwargs): - query = session.query(Pet) - if kind: - # XXX Why kind passed in as a str instead of a Hairkind instance? 
- query = query.filter(Pet.hair_kind == Hairkind(kind)) - return query.first() - - query = """ - query PetQuery($kind: Hairkind) { - pet(kind: $kind) { - name, - petKind - hairKind - } - } - """ - expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} - schema = graphene.Schema(query=Query) - result = schema.execute(query, variables={"kind": "SHORT"}) - assert not result.errors - assert result.data == {"pet": None} - result = schema.execute(query, variables={"kind": "LONG"}) - assert not result.errors - assert result.data == expected, result.data - - -def test_should_node(session): - setup_fixtures(session) + result = to_std_dicts(result.data) + assert result == expected + + +def test_query_node(session): + add_test_data(session) class ReporterNode(SQLAlchemyObjectType): class Meta: @@ -204,31 +96,19 @@ model = Article interfaces = (Node,) - # @classmethod - # def get_node(cls, id, info): - # return Article(id=1, headline='Article node') - - class ArticleConnection(Connection): - class Meta: - node = ArticleNode - class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) - article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleConnection) - - def resolve_reporter(self, *args, **kwargs): + all_articles = SQLAlchemyConnectionField(ArticleNode.connection) + + def resolve_reporter(self, _info): return session.query(Reporter).first() - def resolve_article(self, *args, **kwargs): - return session.query(Article).first() - - query = """ - query ReporterQuery { + query = """ + query { reporter { - id, - firstName, + id + firstName articles { edges { node { @@ -236,8 +116,6 @@ } } } - lastName, - email } allArticles { edges { @@ -260,9 +138,7 @@ expected = { "reporter": { "id": "UmVwb3J0ZXJOb2RlOjE=", - "firstName": "ABA", - "lastName": "X", - "email": None, + "firstName": "John", "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, @@ -271,31 +147,95 @@ schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected - - -def test_should_custom_identifier(session): - setup_fixtures(session) + result = to_std_dicts(result.data) + assert result == expected + + +def test_orm_field(session): + add_test_data(session) + + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return graphene.String() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + first_name_v2 = ORMField(model_attr='first_name') + hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') + column_prop_v2 = ORMField(model_attr='column_prop') + composite_prop = ORMField() + favorite_article_v2 = ORMField(model_attr='favorite_article') + articles_v2 = ORMField(model_attr='articles') + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + query = """ + query { + reporter { + firstNameV2 + hybridPropV2 + columnPropV2 + compositeProp + favoriteArticleV2 { + headline + } + articlesV2(first: 1) { + edges { + node { + headline + } + } + } + } + } + """ + expected = { + "reporter": { + "firstNameV2": "John", + "hybridPropV2": "John", + "columnPropV2": 2, + "compositeProp": "John Doe", + 
"favoriteArticleV2": {"headline": "Hi!"}, + "articlesV2": {"edges": [{"node": {"headline": "Hi!"}}]}, + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_custom_identifier(session): + add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: model = Editor interfaces = (Node,) - class EditorConnection(Connection): - class Meta: - node = EditorNode - class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorConnection) - - query = """ - query EditorQuery { + all_editors = SQLAlchemyConnectionField(EditorNode.connection) + + query = """ + query { allEditors { edges { node { - id, + id name } } @@ -308,18 +248,19 @@ } """ expected = { - "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]}, - "node": {"name": "John"}, + "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]}, + "node": {"name": "Jack"}, } schema = graphene.Schema(query=Query) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected - - -def test_should_mutate_well(session): - setup_fixtures(session) + result = to_std_dicts(result.data) + assert result == expected + + +def test_mutation(session): + add_test_data(session) class EditorNode(SQLAlchemyObjectType): class Meta: @@ -364,7 +305,7 @@ create_article = CreateArticle.Field() query = """ - mutation ArticleCreator { + mutation { createArticle( headline: "My Article" reporterId: "1" @@ -385,7 +326,7 @@ "ok": True, "article": { "headline": "My Article", - "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"}, + "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"}, }, } } @@ -393,165 +334,5 @@ schema = graphene.Schema(query=Query, mutation=Mutation) result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data == expected - - -def sort_setup(session): - pets = [ - Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), - Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), - Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), - ] - session.add_all(pets) - session.commit() - - -def test_sort(session): - sort_setup(session) - - class PetNode(SQLAlchemyObjectType): - class Meta: - model = Pet - interfaces = (Node,) - - class PetConnection(Connection): - class Meta: - node = PetNode - - class Query(graphene.ObjectType): - defaultSort = SQLAlchemyConnectionField(PetConnection) - nameSort = SQLAlchemyConnectionField(PetConnection) - multipleSort = SQLAlchemyConnectionField(PetConnection) - descSort = SQLAlchemyConnectionField(PetConnection) - singleColumnSort = SQLAlchemyConnectionField( - PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet)) - ) - noDefaultSort = SQLAlchemyConnectionField( - PetConnection, sort=sort_argument_for_model(Pet, False) - ) - noSort = SQLAlchemyConnectionField(PetConnection, sort=None) - - query = """ - query sortTest { - defaultSort{ - edges{ - node{ - id - } - } - } - nameSort(sort: name_asc){ - edges{ - node{ - name - } - } - } - multipleSort(sort: [pet_kind_asc, name_desc]){ - edges{ - node{ - name - petKind - } - } - } - descSort(sort: [name_desc]){ - edges{ - node{ - name - } - } - } - singleColumnSort(sort: name_desc){ - edges{ - node{ - name - } - } - } - noDefaultSort(sort: name_asc){ - 
edges{ - node{ - name - } - } - } - } - """ - - def makeNodes(nodeList): - nodes = [{"node": item} for item in nodeList] - return {"edges": nodes} - - expected = { - "defaultSort": makeNodes( - [{"id": "UGV0Tm9kZToy"}, {"id": "UGV0Tm9kZToz"}, {"id": "UGV0Tm9kZToyMg=="}] - ), - "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), - "noDefaultSort": makeNodes( - [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] - ), - "multipleSort": makeNodes( - [ - {"name": "Alf", "petKind": "cat"}, - {"name": "Lassie", "petKind": "dog"}, - {"name": "Barf", "petKind": "dog"}, - ] - ), - "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), - "singleColumnSort": makeNodes( - [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] - ), - } # yapf: disable - - schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={"session": session}) - assert not result.errors - assert result.data == expected - - queryError = """ - query sortTest { - singleColumnSort(sort: [pet_kind_asc, name_desc]){ - edges{ - node{ - name - } - } - } - } - """ - result = schema.execute(queryError, context_value={"session": session}) - assert result.errors is not None - - queryNoSort = """ - query sortTest { - noDefaultSort{ - edges{ - node{ - name - } - } - } - noSort{ - edges{ - node{ - name - } - } - } - } - """ - - expectedNoSort = { - "noDefaultSort": makeNodes( - [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] - ), - "noSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), - } # yapf: disable - - result = schema.execute(queryNoSort, context_value={"session": session}) - assert not result.errors - for key, value in result.data.items(): - assert set(node["node"]["name"] for node in value["edges"]) == set( - node["node"]["name"] for node in expectedNoSort[key]["edges"] - ) + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/tests/test_query_enums.py b/graphene_sqlalchemy/tests/test_query_enums.py new file mode 100644 index 0000000..ec585d5 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_query_enums.py @@ -0,0 +1,198 @@ +import graphene + +from ..types import SQLAlchemyObjectType +from .models import HairKind, Pet, Reporter +from .test_query import add_test_data, to_std_dicts + + +def test_query_pet_kinds(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + + class Meta: + model = Pet + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + reporters = graphene.List(ReporterType) + pets = graphene.List(PetType, kind=graphene.Argument( + PetType.enum_for_field('pet_kind'))) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + def resolve_reporters(self, _info): + return session.query(Reporter) + + def resolve_pets(self, _info, kind): + query = session.query(Pet) + if kind: + query = query.filter_by(pet_kind=kind) + return query + + query = """ + query ReporterQuery { + reporter { + firstName + lastName + email + favoritePetKind + pets { + name + petKind + } + } + reporters { + firstName + favoritePetKind + } + pets(kind: DOG) { + name + petKind + } + } + """ + expected = { + 'reporter': { + 'firstName': 'John', + 'lastName': 'Doe', + 'email': None, + 'favoritePetKind': 'CAT', + 'pets': [{ + 'name': 'Garfield', + 'petKind': 'CAT' + }] + }, + 'reporters': [{ + 'firstName': 'John', + 'favoritePetKind': 'CAT', + }, { + 
'firstName': 'Jane', + 'favoritePetKind': 'DOG', + }], + 'pets': [{ + 'name': 'Lassie', + 'petKind': 'DOG' + }] + } + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_query_more_enums(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType) + + def resolve_pet(self, _info): + return session.query(Pet).first() + + query = """ + query PetQuery { + pet { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + +def test_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) + + def resolve_pet(self, info, kind=None): + query = session.query(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind) + return query.first() + + query = """ + query PetQuery($kind: PetKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "CAT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "DOG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected + + +def test_py_enum_as_argument(session): + add_test_data(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field( + PetType, + kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), + ) + + def resolve_pet(self, _info, kind=None): + query = session.query(Pet) + if kind: + # enum arguments are expected to be strings, not PyEnums + query = query.filter(Pet.hair_kind == HairKind(kind)) + return query.first() + + query = """ + query PetQuery($kind: HairKind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "SHORT"}) + assert not result.errors + expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} + assert result.data == expected + result = schema.execute(query, variables={"kind": "LONG"}) + assert not result.errors + expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} + result = to_std_dicts(result.data) + assert result == expected diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py index 1945af6..0403c4f 100644 --- a/graphene_sqlalchemy/tests/test_registry.py +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -1,25 +1,15 @@ import pytest +from sqlalchemy.types import Enum as SQLAlchemyEnum + +from graphene import Enum as GrapheneEnum from ..registry import Registry from ..types import SQLAlchemyObjectType +from ..utils import EnumValue from .models import Pet -def test_register_incorrect_objecttype(): - reg = Registry() - - class Spam: - pass - - 
with pytest.raises(AssertionError) as excinfo: - reg.register(Spam) - - assert "Only classes of type SQLAlchemyObjectType can be registered" in str( - excinfo.value - ) - - -def test_register_objecttype(): +def test_register_object_type(): reg = Registry() class PetType(SQLAlchemyObjectType): @@ -27,7 +17,112 @@ model = Pet registry = reg - try: - reg.register(PetType) - except AssertionError: - pytest.fail("expected no AssertionError") + reg.register(PetType) + assert reg.get_type_for_model(Pet) is PetType + + +def test_register_incorrect_object_type(): + reg = Registry() + + class Spam: + pass + + re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register(Spam) + + +def test_register_orm_field(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + reg.register_orm_field(PetType, "name", Pet.name) + assert reg.get_orm_field_for_graphene_field(PetType, "name") is Pet.name + + +def test_register_orm_field_incorrect_types(): + reg = Registry() + + class Spam: + pass + + re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(Spam, "name", Pet.name) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + re_err = "Expected a field name, but got: .*Spam" + with pytest.raises(TypeError, match=re_err): + reg.register_orm_field(PetType, Spam, Pet.name) + + +def test_register_enum(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + reg.register_enum(sa_enum, graphene_enum) + assert reg.get_graphene_enum_for_sa_enum(sa_enum) is graphene_enum + + +def test_register_enum_incorrect_types(): + reg = Registry() + + sa_enum = SQLAlchemyEnum("cat", "dog") + graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) + + re_err = r"Expected Graphene Enum, but got: Enum\('cat', 'dog'\)" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(sa_enum, sa_enum) + + re_err = r"Expected SQLAlchemyEnumType, but got: .*PetKind.*" + with pytest.raises(TypeError, match=re_err): + reg.register_enum(graphene_enum, graphene_enum) + + +def test_register_sort_enum(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], + ) + + reg.register_sort_enum(PetType, sort_enum) + assert reg.get_sort_enum_for_object_type(PetType) is sort_enum + + +def test_register_sort_enum_incorrect_types(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + sort_enum = GrapheneEnum( + "PetSort", + [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], + ) + + re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(sort_enum, sort_enum) + + re_err = r"Expected Graphene Enum, but got: .*PetType.*" + with pytest.raises(TypeError, match=re_err): + reg.register_sort_enum(PetType, PetType) diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py deleted file mode 100644 index 628da18..0000000 --- a/graphene_sqlalchemy/tests/test_schema.py +++ /dev/null @@ -1,49 +0,0 @@ -from py.test import raises - -from ..registry import Registry -from ..types import 
SQLAlchemyObjectType -from .models import Reporter - - -def test_should_raise_if_no_model(): - with raises(Exception) as excinfo: - - class Character1(SQLAlchemyObjectType): - pass - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_raise_if_model_is_invalid(): - with raises(Exception) as excinfo: - - class Character2(SQLAlchemyObjectType): - class Meta: - model = 1 - - assert "valid SQLAlchemy Model" in str(excinfo.value) - - -def test_should_map_fields_correctly(): - class ReporterType2(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = Registry() - - assert list(ReporterType2._meta.fields.keys()) == [ - "id", - "first_name", - "last_name", - "email", - "pets", - "articles", - "favorite_article", - ] - - -def test_should_map_only_few_fields(): - class Reporter2(SQLAlchemyObjectType): - class Meta: - model = Reporter - only_fields = ("id", "email") - assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_sqlalchemy/tests/test_sort_enums.py b/graphene_sqlalchemy/tests/test_sort_enums.py new file mode 100644 index 0000000..d6f6965 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_sort_enums.py @@ -0,0 +1,385 @@ +import pytest +import sqlalchemy as sa + +from graphene import Argument, Enum, List, ObjectType, Schema +from graphene.relay import Node + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from ..utils import to_type_name +from .models import Base, HairKind, Pet +from .test_query import to_std_dicts + + +def add_pets(session): + pets = [ + Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), + Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), + ] + session.add_all(pets) + session.commit() + + +def test_sort_enum(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + +def test_sort_enum_with_custom_name(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(name="CustomSortName") + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "CustomSortName" + + +def test_sort_enum_cache(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum() + sort_enum_2 = PetType.sort_enum() + assert sort_enum_2 is sort_enum + sort_enum_2 = PetType.sort_enum(name="PetTypeSortEnum") + assert sort_enum_2 is sort_enum + err_msg = "Sort enum for PetType has already been customized" + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(name="CustomSortName") + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_fields=["id"]) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(only_indexed=True) + with pytest.raises(ValueError, match=err_msg): + PetType.sort_enum(get_symbol_name=lambda: "foo") + 
+ +def test_sort_enum_with_excluded_field_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["reporter_id"] + + sort_enum = PetType.sort_enum() + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + ] + + +def test_sort_enum_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_enum = PetType.sort_enum(only_fields=["id", "name"]) + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + ] + + +def test_sort_argument(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + sort_arg = PetType.sort_argument() + assert isinstance(sort_arg, Argument) + + assert isinstance(sort_arg.type, List) + sort_enum = sort_arg.type._of_type + assert isinstance(sort_enum, type(Enum)) + assert sort_enum._meta.name == "PetTypeSortEnum" + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + "HAIR_KIND_ASC", + "HAIR_KIND_DESC", + "REPORTER_ID_ASC", + "REPORTER_ID_DESC", + ] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" + assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" + assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" + + assert sort_arg.default_value == ["ID_ASC"] + assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" + + +def test_sort_argument_with_excluded_fields_in_object_type(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + exclude_fields = ["hair_kind", "reporter_id"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "NAME_ASC", + "NAME_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_only_fields(): + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + only_fields = ["id", "pet_kind"] + + sort_arg = PetType.sort_argument() + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "PET_KIND_ASC", + "PET_KIND_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_for_multi_column_pk(): + class MultiPkTestModel(Base): + __tablename__ = "multi_pk_test_table" + foo = sa.Column(sa.Integer, primary_key=True) + bar = sa.Column(sa.Integer, primary_key=True) + + class MultiPkTestType(SQLAlchemyObjectType): + class Meta: + model = MultiPkTestModel + + sort_arg = MultiPkTestType.sort_argument() + assert sort_arg.default_value == ["FOO_ASC", "BAR_ASC"] + + +def test_sort_argument_only_indexed(): + class IndexedTestModel(Base): + __tablename__ = "indexed_test_table" + id = sa.Column(sa.Integer, primary_key=True) + foo = sa.Column(sa.Integer, index=False) + bar = sa.Column(sa.Integer, index=True) + + class IndexedTestType(SQLAlchemyObjectType): + class Meta: + model = IndexedTestModel + + sort_arg = IndexedTestType.sort_argument(only_indexed=True) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "ID_ASC", + "ID_DESC", + "BAR_ASC", + "BAR_DESC", + ] + assert sort_arg.default_value == ["ID_ASC"] + + +def test_sort_argument_with_custom_symbol_names(): + class 
PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + def get_symbol_name(column_name, sort_asc=True): + return to_type_name(column_name) + ("Up" if sort_asc else "Down") + + sort_arg = PetType.sort_argument(get_symbol_name=get_symbol_name) + sort_enum = sort_arg.type._of_type + assert list(sort_enum._meta.enum.__members__) == [ + "IdUp", + "IdDown", + "NameUp", + "NameDown", + "PetKindUp", + "PetKindDown", + "HairKindUp", + "HairKindDown", + "ReporterIdUp", + "ReporterIdDown", + ] + assert sort_arg.default_value == ["IdUp"] + + +def test_sort_query(session): + add_pets(session) + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + class Query(ObjectType): + defaultSort = SQLAlchemyConnectionField(PetNode.connection) + nameSort = SQLAlchemyConnectionField(PetNode.connection) + multipleSort = SQLAlchemyConnectionField(PetNode.connection) + descSort = SQLAlchemyConnectionField(PetNode.connection) + singleColumnSort = SQLAlchemyConnectionField( + PetNode.connection, sort=Argument(PetNode.sort_enum()) + ) + noDefaultSort = SQLAlchemyConnectionField( + PetNode.connection, sort=PetNode.sort_argument(has_default=False) + ) + noSort = SQLAlchemyConnectionField(PetNode.connection, sort=None) + + query = """ + query sortTest { + defaultSort { + edges { + node { + name + } + } + } + nameSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + multipleSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + petKind + } + } + } + descSort(sort: [NAME_DESC]) { + edges { + node { + name + } + } + } + singleColumnSort(sort: NAME_DESC) { + edges { + node { + name + } + } + } + noDefaultSort(sort: NAME_ASC) { + edges { + node { + name + } + } + } + } + """ + + def makeNodes(nodeList): + nodes = [{"node": item} for item in nodeList] + return {"edges": nodes} + + expected = { + "defaultSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), + "noDefaultSort": makeNodes( + [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] + ), + "multipleSort": makeNodes( + [ + {"name": "Alf", "petKind": "CAT"}, + {"name": "Lassie", "petKind": "DOG"}, + {"name": "Barf", "petKind": "DOG"}, + ] + ), + "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), + "singleColumnSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + } # yapf: disable + + schema = Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + result = to_std_dicts(result.data) + assert result == expected + + queryError = """ + query sortTest { + singleColumnSort(sort: [PET_KIND_ASC, NAME_DESC]) { + edges { + node { + name + } + } + } + } + """ + result = schema.execute(queryError, context_value={"session": session}) + assert result.errors is not None + assert '"sort" has invalid value' in result.errors[0].message + + queryNoSort = """ + query sortTest { + noDefaultSort { + edges { + node { + name + } + } + } + noSort { + edges { + node { + name + } + } + } + } + """ + + result = schema.execute(queryNoSort, context_value={"session": session}) + assert not result.errors + # TODO: SQLite usually returns the results ordered by primary key, + # so we cannot test this way whether sorting actually happens or not. + # Also, no sort order is guaranteed by SQLite if "no order" by is used. 
+ assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ + node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] + ] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 0360a64..bf563b6 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,193 +1,417 @@ -from collections import OrderedDict - +import mock +import pytest import six # noqa F401 -from promise import Promise - -from graphene import (Connection, Field, Int, Interface, Node, ObjectType, - is_node) - + +from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, + ObjectType, Schema, String) +from graphene.relay import Connection + +from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, - UnsortedSQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, createConnectionField, registerConnectionFieldFactory, unregisterConnectionFieldFactory) -from ..registry import Registry -from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions -from .models import Article, Reporter - -registry = Registry() - - -class Character(SQLAlchemyObjectType): - """Character description""" - - class Meta: - model = Reporter - registry = registry - - -class Human(SQLAlchemyObjectType): - """Human description""" - - pub_date = Int() - - class Meta: - model = Article - exclude_fields = ("id",) - registry = registry - interfaces = (Node,) - - -def test_sqlalchemy_interface(): - assert issubclass(Node, Interface) - assert issubclass(Node, Node) - - -# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) -# def test_sqlalchemy_get_node(get): -# human = Human.get_node(1, None) -# get.assert_called_with(id=1) -# assert human.id == 1 - - -def test_objecttype_registered(): - assert issubclass(Character, ObjectType) - assert Character._meta.model == Reporter - assert list(Character._meta.fields.keys()) == [ +from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions +from .models import Article, CompositeFullName, Pet, Reporter + + +def test_should_raise_if_no_model(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + class Character1(SQLAlchemyObjectType): + pass + + +def test_should_raise_if_model_is_invalid(): + re_err = r"valid SQLAlchemy Model" + with pytest.raises(Exception, match=re_err): + class Character(SQLAlchemyObjectType): + class Meta: + model = 1 + + +def test_sqlalchemy_node(session): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + reporter_id_field = ReporterType._meta.fields["id"] + assert isinstance(reporter_id_field, GlobalID) + + reporter = Reporter() + session.add(reporter) + session.commit() + info = mock.Mock(context={'session': session}) + reporter_node = ReporterType.get_node(info, reporter.id) + assert reporter == reporter_node + + +def test_connection(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + assert issubclass(ReporterType.connection, Connection) + + +def test_sqlalchemy_default_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = 
(Node,) + + assert list(ReporterType._meta.fields.keys()) == [ + # Columns + "column_prop", # SQLAlchemy returns column properties first "id", "first_name", "last_name", "email", + "favorite_pet_kind", + # Composite + "composite_prop", + # Hybrid + "hybrid_prop", + # Relationship "pets", "articles", "favorite_article", ] - -# def test_sqlalchemynode_idfield(): -# idfield = Node._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) - - -# def test_node_idfield(): -# idfield = Human._meta.fields_map['id'] -# assert isinstance(idfield, GlobalIDField) - - -def test_node_replacedfield(): - idfield = Human._meta.fields["pub_date"] - assert isinstance(idfield, Field) - assert idfield.type == Int - - -def test_object_type(): - class Human(SQLAlchemyObjectType): - """Human description""" - - pub_date = Int() - + # column + first_name_field = ReporterType._meta.fields['first_name'] + assert first_name_field.type == String + assert first_name_field.description == "First name" + + # column_property + column_prop_field = ReporterType._meta.fields['column_prop'] + assert column_prop_field.type == Int + # "doc" is ignored by column_property + assert column_prop_field.description is None + + # composite + full_name_field = ReporterType._meta.fields['composite_prop'] + assert full_name_field.type == String + # "doc" is ignored by composite + assert full_name_field.description is None + + # hybrid_property + hybrid_prop = ReporterType._meta.fields['hybrid_prop'] + assert hybrid_prop.type == String + # "doc" is ignored by hybrid_property + assert hybrid_prop.description is None + + # relationship + favorite_article_field = ReporterType._meta.fields['favorite_article'] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description is None + + +def test_sqlalchemy_override_fields(): + @convert_sqlalchemy_composite.register(CompositeFullName) + def convert_composite_class(composite, registry): + return String() + + class ReporterMixin(object): + # columns + first_name = ORMField(required=True) + last_name = ORMField(description='Overridden') + + class ReporterType(SQLAlchemyObjectType, ReporterMixin): + class Meta: + model = Reporter + interfaces = (Node,) + + # columns + email = ORMField(deprecation_reason='Overridden') + email_v2 = ORMField(model_attr='email', type=Int) + + # column_property + column_prop = ORMField(type=String) + + # composite + composite_prop = ORMField() + + # hybrid_property + hybrid_prop = ORMField(description='Overridden') + + # relationships + favorite_article = ORMField(description='Overridden') + articles = ORMField(deprecation_reason='Overridden') + pets = ORMField(description='Overridden') + + class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - # exclude_fields = ('id', ) - registry = registry - interfaces = (Node,) - - assert issubclass(Human, ObjectType) - assert list(Human._meta.fields.keys()) == [ - "id", - "headline", - "pub_date", - "reporter_id", - "reporter", - ] - assert is_node(Human) - - -# Test Custom SQLAlchemyObjectType Implementation -class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): - class Meta: - abstract = True - - -class CustomCharacter(CustomSQLAlchemyObjectType): - """Character description""" - - class Meta: - model = Reporter - registry = registry - - -def test_custom_objecttype_registered(): - assert issubclass(CustomCharacter, ObjectType) - assert CustomCharacter._meta.model == Reporter - assert
list(CustomCharacter._meta.fields.keys()) == [ - "id", + interfaces = (Node,) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + use_connection = False + + assert list(ReporterType._meta.fields.keys()) == [ + # Fields from ReporterMixin "first_name", "last_name", + # Fields from ReporterType "email", + "email_v2", + "column_prop", + "composite_prop", + "hybrid_prop", + "favorite_article", + "articles", + "pets", + # Then the automatic SQLAlchemy fields + "id", + "favorite_pet_kind", + ] + + first_name_field = ReporterType._meta.fields['first_name'] + assert isinstance(first_name_field.type, NonNull) + assert first_name_field.type.of_type == String + assert first_name_field.description == "First name" + assert first_name_field.deprecation_reason is None + + last_name_field = ReporterType._meta.fields['last_name'] + assert last_name_field.type == String + assert last_name_field.description == "Overridden" + assert last_name_field.deprecation_reason is None + + email_field = ReporterType._meta.fields['email'] + assert email_field.type == String + assert email_field.description == "Email" + assert email_field.deprecation_reason == "Overridden" + + email_field_v2 = ReporterType._meta.fields['email_v2'] + assert email_field_v2.type == Int + assert email_field_v2.description == "Email" + assert email_field_v2.deprecation_reason is None + + hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] + assert hybrid_prop_field.type == String + assert hybrid_prop_field.description == "Overridden" + assert hybrid_prop_field.deprecation_reason is None + + column_prop_field_v2 = ReporterType._meta.fields['column_prop'] + assert column_prop_field_v2.type == String + assert column_prop_field_v2.description is None + assert column_prop_field_v2.deprecation_reason is None + + composite_prop_field = ReporterType._meta.fields['composite_prop'] + assert composite_prop_field.type == String + assert composite_prop_field.description is None + assert composite_prop_field.deprecation_reason is None + + favorite_article_field = ReporterType._meta.fields['favorite_article'] + assert isinstance(favorite_article_field, Dynamic) + assert favorite_article_field.type().type == ArticleType + assert favorite_article_field.type().description == 'Overridden' + + articles_field = ReporterType._meta.fields['articles'] + assert isinstance(articles_field, Dynamic) + assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) + assert articles_field.type().deprecation_reason == "Overridden" + + pets_field = ReporterType._meta.fields['pets'] + assert isinstance(pets_field, Dynamic) + assert isinstance(pets_field.type().type, List) + assert pets_field.type().type.of_type == PetType + assert pets_field.type().description == 'Overridden' + + +def test_invalid_model_attr(): + err_msg = ( + "Cannot map ORMField to a model attribute.\n" + "Field: 'ReporterType.first_name'" + ) + with pytest.raises(ValueError, match=err_msg): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + first_name = ORMField(model_attr='does_not_exist') + + +def test_only_fields(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") + + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop + + assert list(ReporterType._meta.fields.keys()) == ["first_name", "last_name", "id"] + + +def test_exclude_fields(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + 
exclude_fields = ("id", "first_name") + + first_name = ORMField() # Takes precedence + last_name = ORMField() # Noop + + assert list(ReporterType._meta.fields.keys()) == [ + "first_name", + "last_name", + "column_prop", + "email", + "favorite_pet_kind", + "composite_prop", + "hybrid_prop", "pets", "articles", "favorite_article", ] +def test_only_and_exclude_fields(): + re_err = r"'only_fields' and 'exclude_fields' cannot be both set" + with pytest.raises(Exception, match=re_err): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + only_fields = ("id", "last_name") + exclude_fields = ("id", "last_name") + + +def test_sqlalchemy_redefine_field(): + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + + first_name = Int() + + first_name_field = ReporterType._meta.fields["first_name"] + assert isinstance(first_name_field, Field) + assert first_name_field.type == Int + + +def test_resolvers(session): + """Test that the correct resolver functions are called""" + + class ReporterMixin(object): + def resolve_id(root, _info): + return 'ID' + + class ReporterType(ReporterMixin, SQLAlchemyObjectType): + class Meta: + model = Reporter + + email = ORMField() + email_v2 = ORMField(model_attr='email') + favorite_pet_kind = Field(String) + favorite_pet_kind_v2 = Field(String) + + def resolve_last_name(root, _info): + return root.last_name.upper() + + def resolve_email_v2(root, _info): + return root.email + '_V2' + + def resolve_favorite_pet_kind_v2(root, _info): + return str(root.favorite_pet_kind) + '_V2' + + class Query(ObjectType): + reporter = Field(ReporterType) + + def resolve_reporter(self, _info): + return session.query(Reporter).first() + + reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat') + session.add(reporter) + session.commit() + + schema = Schema(query=Query) + result = schema.execute(""" + query { + reporter { + id + firstName + lastName + email + emailV2 + favoritePetKind + favoritePetKindV2 + } + } + """) + + assert not result.errors + # Custom resolver on a base class + assert result.data['reporter']['id'] == 'ID' + # Default field + default resolver + assert result.data['reporter']['firstName'] == 'first_name' + # Default field + custom resolver + assert result.data['reporter']['lastName'] == 'LAST_NAME' + # ORMField + default resolver + assert result.data['reporter']['email'] == 'email' + # ORMField + custom resolver + assert result.data['reporter']['emailV2'] == 'email_V2' + # Field + default resolver + assert result.data['reporter']['favoritePetKind'] == 'cat' + # Field + custom resolver + assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2' + + +# Test Custom SQLAlchemyObjectType Implementation + +def test_custom_objecttype_registered(): + class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): + class Meta: + abstract = True + + class CustomReporterType(CustomSQLAlchemyObjectType): + class Meta: + model = Reporter + + assert issubclass(CustomReporterType, ObjectType) + assert CustomReporterType._meta.model == Reporter + assert len(CustomReporterType._meta.fields) == 11 + + # Test Custom SQLAlchemyObjectType with Custom Options -class CustomOptions(SQLAlchemyObjectTypeOptions): - custom_option = None - custom_fields = None - - -class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): - class Meta: - abstract = True - - @classmethod - def __init_subclass_with_meta__( - cls, custom_option=None, custom_fields=None, **options - ): - _meta = CustomOptions(cls) 
- _meta.custom_option = custom_option - _meta.fields = custom_fields - super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( - _meta=_meta, **options - ) - - -class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): - class Meta: - model = Reporter - custom_option = "custom_option" - custom_fields = OrderedDict([("custom_field", Field(Int()))]) - - def test_objecttype_with_custom_options(): + class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + + class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__(cls, custom_option=None, **options): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = "custom_option" + assert issubclass(ReporterWithCustomOptions, ObjectType) assert ReporterWithCustomOptions._meta.model == Reporter - assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ - "custom_field", - "id", - "first_name", - "last_name", - "email", - "pets", - "articles", - "favorite_article", - ] assert ReporterWithCustomOptions._meta.custom_option == "custom_option" - assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) - - -def test_promise_connection_resolver(): - class TestConnection(Connection): - class Meta: - node = ReporterWithCustomOptions - - def resolver(*args, **kwargs): - return Promise.resolve([]) - - result = SQLAlchemyConnectionField.connection_resolver( - resolver, TestConnection, ReporterWithCustomOptions, None, None - ) - assert result is not None # Tests for connection_field_factory @@ -197,83 +421,74 @@ def test_default_connection_field_factory(): - _registry = Registry() - - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter interfaces = (Node,) class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) -def test_register_connection_field_factory(): +def test_custom_connection_field_factory(): def test_connection_field_factory(relationship, registry): model = relationship.mapper.entity _type = registry.get_type_for_model(model) return _TestSQLAlchemyConnectionField(_type._meta.connection) - _registry = Registry() - - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter interfaces = (Node,) connection_field_factory = test_connection_field_factory class ArticleType(SQLAlchemyObjectType): class Meta: model = Article - registry = _registry interfaces = (Node,) assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) def test_deprecated_registerConnectionFieldFactory(): - registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) - - _registry = Registry() - - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry - interfaces = (Node,) - - class ArticleType(SQLAlchemyObjectType): - class Meta: - model = Article - registry = _registry - interfaces = (Node,) - - 
assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + with pytest.warns(DeprecationWarning): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) def test_deprecated_unregisterConnectionFieldFactory(): - registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) - unregisterConnectionFieldFactory() - - _registry = Registry() - - class ReporterType(SQLAlchemyObjectType): - class Meta: - model = Reporter - registry = _registry - interfaces = (Node,) - - class ArticleType(SQLAlchemyObjectType): - class Meta: - model = Article - registry = _registry - interfaces = (Node,) - - assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + with pytest.warns(DeprecationWarning): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + unregisterConnectionFieldFactory() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + + assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_createConnectionField(): + with pytest.warns(DeprecationWarning): + createConnectionField(None) diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index a7b902f..e13d919 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,9 +1,11 @@ +import pytest import sqlalchemy as sa from graphene import Enum, List, ObjectType, Schema, String -from ..utils import get_session, sort_argument_for_model, sort_enum_for_model -from .models import Editor, Pet +from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, + to_enum_value_name, to_type_name) +from .models import Base, Editor, Pet def test_get_session(): @@ -27,8 +29,25 @@ assert result.data["x"] == session +def test_to_type_name(): + assert to_type_name("make_camel_case") == "MakeCamelCase" + assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase" + assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel" + + +def test_to_enum_value_name(): + assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME" + assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE" + assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME" + + +# test deprecated sort enum utility functions + + def test_sort_enum_for_model(): - enum = sort_enum_for_model(Pet) + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model(Pet) assert isinstance(enum, type(Enum)) assert str(enum) == "PetSortEnum" for col in sa.inspect(Pet).columns: @@ -37,7 +56,10 @@ def test_sort_enum_for_model_custom_naming(): - enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")) + with pytest.warns(DeprecationWarning): + enum = sort_enum_for_model( + Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D") + ) assert str(enum) == "Foo" for col in sa.inspect(Pet).columns: assert hasattr(enum, col.name.upper() + "A") 
@@ -45,32 +67,35 @@ def test_enum_cache(): - assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) + with pytest.warns(DeprecationWarning): + assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) def test_sort_argument_for_model(): - arg = sort_argument_for_model(Pet) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet) assert isinstance(arg.type, List) assert arg.default_value == [Pet.id.name + "_asc"] - assert arg.type.of_type == sort_enum_for_model(Pet) + with pytest.warns(DeprecationWarning): + assert arg.type.of_type is sort_enum_for_model(Pet) def test_sort_argument_for_model_no_default(): - arg = sort_argument_for_model(Pet, False) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(Pet, False) assert arg.default_value is None def test_sort_argument_for_model_multiple_pk(): - Base = sa.ext.declarative.declarative_base() - class MultiplePK(Base): foo = sa.Column(sa.Integer, primary_key=True) bar = sa.Column(sa.Integer, primary_key=True) __tablename__ = "MultiplePK" - arg = sort_argument_for_model(MultiplePK) + with pytest.warns(DeprecationWarning): + arg = sort_argument_for_model(MultiplePK) assert set(arg.default_value) == set( (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") ) diff --git a/graphene_sqlalchemy/tests/utils.py b/graphene_sqlalchemy/tests/utils.py new file mode 100644 index 0000000..428757c --- /dev/null +++ b/graphene_sqlalchemy/tests/utils.py @@ -0,0 +1,16 @@ +import pkg_resources + + +def to_std_dicts(value): + """Convert nested ordered dicts to normal dicts for better comparison.""" + if isinstance(value, dict): + return {k: to_std_dicts(v) for k, v in value.items()} + elif isinstance(value, list): + return [to_std_dicts(v) for v in value] + else: + return value + + +def is_sqlalchemy_version_less_than(version_string): + """Check the installed SQLAlchemy version""" + return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 394d506..ff22cde 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -2,79 +2,181 @@ import sqlalchemy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm import (ColumnProperty, CompositeProperty, + RelationshipProperty) from sqlalchemy.orm.exc import NoResultFound -from graphene import Field # , annotate, ResolveInfo +from graphene import Field from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs +from graphene.utils.orderedtype import OrderedType from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_hybrid_method, convert_sqlalchemy_relationship) -from .fields import default_connection_field_factory +from .enums import (enum_for_field, sort_argument_for_object_type, + sort_enum_for_object_type) from .registry import Registry, get_global_registry +from .resolvers import get_attr_resolver, get_custom_resolver from .utils import get_query, is_mapped_class, is_mapped_instance -def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory): - inspected_model = sqlalchemyinspect(model) - +class ORMField(OrderedType): + def __init__( + self, + model_attr=None, + type=None, + required=None, + description=None, + deprecation_reason=None, + 
batching=None, + _creation_counter=None, + **field_kwargs + ): + """ + Use this to override fields automatically generated by SQLAlchemyObjectType. + Unless specified, options will default to SQLAlchemyObjectType's usual behavior + for the given SQLAlchemy model property. + + Usage: + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel + + id = ORMField(type=graphene.Int) + name = ORMField(required=True) + + -> MyType.id will be of type Int (vs ID). + -> MyType.name will be of type NonNull(String) (vs String). + + :param str model_attr: + Name of the SQLAlchemy model attribute used to resolve this field. + Defaults to the name of the attribute referencing the ORMField. + :param type: + Defaults to the type mapping in converter.py. + :param bool required: + Defaults to the opposite of the `nullable` attribute of the SQLAlchemy column property. + :param str description: + Same behavior as in graphene.Field; defaults to the `doc` attribute of the SQLAlchemy column property. + :param str deprecation_reason: + Same behavior as in graphene.Field. Defaults to None. + :param bool batching: + Toggle SQL batching. Defaults to None, in which case `SQLAlchemyObjectType.meta.batching` is used. + :param int _creation_counter: + Same behavior as in graphene.Field. + """ + super(ORMField, self).__init__(_creation_counter=_creation_counter) + # This is only useful for documentation and auto-completion + common_kwargs = { + 'model_attr': model_attr, + 'type': type, + 'required': required, + 'description': description, + 'deprecation_reason': deprecation_reason, + 'batching': batching, + } + common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} + self.kwargs = field_kwargs + self.kwargs.update(common_kwargs) + + +def construct_fields( + obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory +): + """ + Construct all the fields for a SQLAlchemyObjectType.
+ The main steps are: + - Gather all the relevant attributes from the SQLAlchemy model + - Gather all the ORM fields defined on the type + - Merge in overrides and build up all the fields + + :param SQLAlchemyObjectType obj_type: + :param model: the SQLAlchemy model + :param Registry registry: + :param tuple[string] only_fields: + :param tuple[string] exclude_fields: + :param bool batching: + :param function|None connection_field_factory: + :rtype: OrderedDict[str, graphene.Field] + """ + inspected_model = sqlalchemy.inspect(model) + # Gather all the relevant attributes from the SQLAlchemy model in order + all_model_attrs = OrderedDict( + inspected_model.column_attrs.items() + + inspected_model.composites.items() + + [(name, item) for name, item in inspected_model.all_orm_descriptors.items() + if isinstance(item, hybrid_property)] + + inspected_model.relationships.items() + ) + + # Filter out excluded fields + auto_orm_field_names = [] + for attr_name, attr in all_model_attrs.items(): + if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): + continue + auto_orm_field_names.append(attr_name) + + # Gather all the ORM fields defined on the type + custom_orm_fields_items = [ + (attn_name, attr) + for base in reversed(obj_type.__mro__) + for attn_name, attr in base.__dict__.items() + if isinstance(attr, ORMField) + ] + custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) + + # Set the model_attr if not set + for orm_field_name, orm_field in custom_orm_fields_items: + attr_name = orm_field.kwargs.get('model_attr', orm_field_name) + if attr_name not in all_model_attrs: + raise ValueError(( + "Cannot map ORMField to a model attribute.\n" + "Field: '{}.{}'" + ).format(obj_type.__name__, orm_field_name,)) + orm_field.kwargs['model_attr'] = attr_name + + # Merge automatic fields with custom ORM fields + orm_fields = OrderedDict(custom_orm_fields_items) + for orm_field_name in auto_orm_field_names: + if orm_field_name in orm_fields: + continue + orm_fields[orm_field_name] = ORMField(model_attr=orm_field_name) + + # Build all the field dictionary fields = OrderedDict() - - for name, column in inspected_model.columns.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_column = convert_sqlalchemy_column(column, registry) - fields[name] = converted_column - - for name, composite in inspected_model.composites.items(): - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. 
Or when we exclude this field in exclude_fields - continue - converted_composite = convert_sqlalchemy_composite(composite, registry) - fields[name] = converted_composite - - for hybrid_item in inspected_model.all_orm_descriptors: - - if type(hybrid_item) == hybrid_property: - name = hybrid_item.__name__ - - is_not_in_only = only_fields and name not in only_fields - # is_already_created = name in options.fields - is_excluded = name in exclude_fields # or is_already_created - - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - - converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) - fields[name] = converted_hybrid_property - - # Get all the columns for the relationships on the model - for relationship in inspected_model.relationships: - is_not_in_only = only_fields and relationship.key not in only_fields - # is_already_created = relationship.key in options.fields - is_excluded = relationship.key in exclude_fields # or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we exclude this field in exclude_fields - continue - converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory) - name = relationship.key - fields[name] = converted_relationship + for orm_field_name, orm_field in orm_fields.items(): + attr_name = orm_field.kwargs.pop('model_attr') + attr = all_model_attrs[attr_name] + resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) + + if isinstance(attr, ColumnProperty): + field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) + elif isinstance(attr, RelationshipProperty): + batching_ = orm_field.kwargs.pop('batching', batching) + field = convert_sqlalchemy_relationship( + attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) + elif isinstance(attr, CompositeProperty): + if attr_name != orm_field_name or orm_field.kwargs: + # TODO Add a way to override composite property fields + raise ValueError( + "ORMField kwargs for composite fields must be empty. " + "Field: {}.{}".format(obj_type.__name__, orm_field_name)) + field = convert_sqlalchemy_composite(attr, registry, resolver) + elif isinstance(attr, hybrid_property): + field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) + else: + raise Exception('Property type is not supported') # Should never happen + + registry.register_orm_field(obj_type, orm_field_name, attr) + fields[orm_field_name] = field return fields @@ -100,7 +202,8 @@ use_connection=None, interfaces=(), id=None, - connection_field_factory=default_connection_field_factory, + batching=False, + connection_field_factory=None, _meta=None, **options ): @@ -116,15 +219,21 @@ 'Registry, received "{}".' 
).format(cls.__name__, registry) + if only_fields and exclude_fields: + raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") + sqla_fields = yank_fields_from_attrs( construct_fields( + obj_type=cls, model=model, registry=registry, only_fields=only_fields, exclude_fields=exclude_fields, - connection_field_factory=connection_field_factory + batching=batching, + connection_field_factory=connection_field_factory, ), - _as=Field + _as=Field, + sort=False, ) if use_connection is None and interfaces: @@ -159,6 +268,8 @@ _meta.connection = connection _meta.id = id or "id" + + cls.connection = connection # Public way to get the connection super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( _meta=_meta, interfaces=interfaces, **options @@ -191,3 +302,11 @@ # graphene_type = info.parent_type.graphene_type keys = self.__mapper__.primary_key_from_instance(self) return tuple(keys) if len(keys) > 1 else keys[0] + + @classmethod + def enum_for_field(cls, field_name): + return enum_for_field(cls, field_name) + + sort_enum = classmethod(sort_enum_for_object_type) + + sort_argument = classmethod(sort_argument_for_object_type) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index 276a807..7139eef 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,9 +1,9 @@ +import re +import warnings + from sqlalchemy.exc import ArgumentError -from sqlalchemy.inspection import inspect from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError - -from graphene import Argument, Enum, List def get_session(context): @@ -41,70 +41,102 @@ return True -def _symbol_name(column_name, is_asc): - return column_name + ("_asc" if is_asc else "_desc") +def to_type_name(name): + """Convert the given name to a GraphQL type name.""" + return "".join(part[:1].upper() + part[1:] for part in name.split("_")) + + +_re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)") +_re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])") + + +def to_enum_value_name(name): + """Convert the given name to a GraphQL enum value name.""" + return _re_enum_value_name_2.sub( + r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name) + ).upper() class EnumValue(str): - """Subclass of str that stores a string and an arbitrary value in the "value" property""" + """String that has an additional value attached. - def __new__(cls, str_value, value): - return super(EnumValue, cls).__new__(cls, str_value) + This is used to attach SQLAlchemy model columns to Enum symbols. 
+ """ - def __init__(self, str_value, value): + def __new__(cls, s, value): + return super(EnumValue, cls).__new__(cls, s) + + def __init__(self, _s, value): super(EnumValue, self).__init__() self.value = value -# Cache for the generated enums, to avoid name clash -_ENUM_CACHE = {} +def _deprecated_default_symbol_name(column_name, sort_asc): + return column_name + ("_asc" if sort_asc else "_desc") -def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - name = name or cls.__name__ + "SortEnum" - if name in _ENUM_CACHE: - return _ENUM_CACHE[name] - items = [] - default = [] - for column in inspect(cls).columns.values(): - asc_name = symbol_name(column.name, True) - asc_value = EnumValue(asc_name, column.asc()) - desc_name = symbol_name(column.name, False) - desc_value = EnumValue(desc_name, column.desc()) - if column.primary_key: - default.append(asc_value) - items.extend(((asc_name, asc_value), (desc_name, desc_value))) - enum = Enum(name, items) - _ENUM_CACHE[name] = (enum, default) - return enum, default +# unfortunately, we cannot use lru_cache because we still support Python 2 +_deprecated_object_type_cache = {} -def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): - """Create Graphene Enum for sorting a SQLAlchemy class query +def _deprecated_object_type_for_model(cls, name): - Parameters - - cls : Sqlalchemy model class - Model used to create the sort enumerator - - name : str, optional, default None - Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'` - - symbol_name : function, optional, default `_symbol_name` - Function which takes the column name and a boolean indicating if the sort direction is ascending, - and returns the symbol name for the current column and sort direction. - The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc' + try: + return _deprecated_object_type_cache[cls, name] + except KeyError: + from .types import SQLAlchemyObjectType - Returns - - Enum - The Graphene enumerator + obj_type_name = name or cls.__name__ + + class ObjType(SQLAlchemyObjectType): + class Meta: + name = obj_type_name + model = cls + + _deprecated_object_type_cache[cls, name] = ObjType + return ObjType + + +def sort_enum_for_model(cls, name=None, symbol_name=None): + """Get a Graphene Enum for sorting the given model class. + + This is deprecated, please use object_type.sort_enum() instead. """ - enum, _ = _sort_enum_for_model(cls, name, symbol_name) - return enum + warnings.warn( + "sort_enum_for_model() is deprecated; use object_type.sort_enum() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from .enums import sort_enum_for_object_type + + return sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, name), + name, + get_symbol_name=symbol_name or _deprecated_default_symbol_name, + ) def sort_argument_for_model(cls, has_default=True): - """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model. - If `has_default` is True (the default) it will sort the result by the primary key(s) + """Get a Graphene Argument for sorting the given model class. + + This is deprecated, please use object_type.sort_argument() instead. 
""" - enum, default = _sort_enum_for_model(cls) + warnings.warn( + "sort_argument_for_model() is deprecated;" + " use object_type.sort_argument() instead.", + DeprecationWarning, + stacklevel=2, + ) + + from graphene import Argument, List + from .enums import sort_enum_for_object_type + + enum = sort_enum_for_object_type( + _deprecated_object_type_for_model(cls, None), + get_symbol_name=_deprecated_default_symbol_name, + ) if not has_default: - default = None - return Argument(List(enum), default_value=default) + enum.default = None + + return Argument(List(enum), default_value=enum.default) diff --git a/setup.cfg b/setup.cfg index 7fd23df..4e8e502 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,11 +6,12 @@ max-line-length = 120 [isort] +no_lines_before=FIRSTPARTY known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme known_first_party=graphene_sqlalchemy -known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER -no_lines_before=FIRSTPARTY +skip_glob=examples/nameko_sqlalchemy [bdist_wheel] universal=1 diff --git a/setup.py b/setup.py index 66704b2..7b350c3 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,9 @@ requirements = [ # To keep things simple, we only support newer versions of Graphene "graphene>=2.1.3,<3", + "promise>=2.3", # Tests fail with 1.0.19 - "SQLAlchemy>=1.1,<2", + "SQLAlchemy>=1.2,<2", "six>=1.10.0,<2", "singledispatch>=3.4.0.3,<4", ] @@ -29,6 +30,7 @@ "mock==2.0.0", "pytest-cov==2.6.1", "sqlalchemy_utils==0.33.9", + "pytest-benchmark==3.2.1", ] setup( @@ -47,8 +49,6 @@ "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", - "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", @@ -60,7 +60,7 @@ extras_require={ "dev": [ "tox==3.7.0", # Should be kept in sync with tox.ini - "coveralls==1.7.0", + "coveralls==1.10.0", "pre-commit==1.14.4", ], "test": tests_require, diff --git a/tox.ini b/tox.ini index e55f7d9..562da2d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = pre-commit,py{27,34,35,36,37}-sql{11,12,13} +envlist = pre-commit,py{27,35,36,37}-sql{11,12,13} skipsdist = true minversion = 3.7.0