Codebase list python-graphene-sqlalchemy / 818eac4
Import upstream version 2.3.0 Kali Janitor 3 years ago
47 changed file(s) with 3598 addition(s) and 1174 deletion(s). Raw diff Collapse all Expand all
1010 # Distribution / packaging
1111 .Python
1212 env/
13 .venv/
1314 build/
1415 develop-eggs/
1516 dist/
2425 *.egg-info/
2526 .installed.cfg
2627 *.egg
28 .python-version
2729
2830 # PyInstaller
2931 # Usually these files are written by a python script from a template
4547 coverage.xml
4648 *,cover
4749 .pytest_cache/
50 .benchmarks/
4851
4952 # Translations
5053 *.mo
33 # Python 2.7
44 - env: TOXENV=py27
55 python: 2.7
6 # Python 3.5
7 - env: TOXENV=py34
8 python: 3.4
96 # Python 3.5
107 - env: TOXENV=py35
118 python: 3.5
0 / @cito @jnak @Nabellaleen
4242 class User(SQLAlchemyObjectType):
4343 class Meta:
4444 model = UserModel
45 # only return specified fields
46 only_fields = ("name",)
47 # exclude specified fields
48 exclude_fields = ("last_name",)
45 # use `only_fields` to only expose specific fields ie "name"
46 # only_fields = ("name",)
47 # use `exclude_fields` to exclude specific fields ie "last_name"
48 # exclude_fields = ("last_name",)
4949
5050 class Query(graphene.ObjectType):
5151 users = graphene.List(User)
1212 interfaces = (relay.Node,)
1313
1414
15 class BookConnection(relay.Connection):
16 class Meta:
17 node = Book
18
19
2015 class Author(SQLAlchemyObjectType):
2116 class Meta:
2217 model = AuthorModel
2318 interfaces = (relay.Node,)
24
25
26 class AuthorConnection(relay.Connection):
27 class Meta:
28 node = Author
2919
3020
3121 class SearchResult(graphene.Union):
3828 search = graphene.List(SearchResult, q=graphene.String()) # List field for search results
3929
4030 # Normal Fields
41 all_books = SQLAlchemyConnectionField(BookConnection)
42 all_authors = SQLAlchemyConnectionField(AuthorConnection)
31 all_books = SQLAlchemyConnectionField(Book.connection)
32 all_authors = SQLAlchemyConnectionField(Author.connection)
4333
4434 def resolve_search(self, info, **args):
4535 q = args.get("q") # Search query
4949 model = Pet
5050
5151
52 class PetConnection(Connection):
53 class Meta:
54 node = PetNode
55
56
5752 class Query(ObjectType):
58 allPets = SQLAlchemyConnectionField(PetConnection)
53 allPets = SQLAlchemyConnectionField(PetNode.connection)
5954
6055 some of the allowed queries are
6156
101101 interfaces = (relay.Node, )
102102
103103
104 class DepartmentConnection(relay.Connection):
105 class Meta:
106 node = Department
107
108
109104 class Employee(SQLAlchemyObjectType):
110105 class Meta:
111106 model = EmployeeModel
112107 interfaces = (relay.Node, )
113108
114109
115 class EmployeeConnection(relay.Connection):
116 class Meta:
117 node = Employee
118
119
120110 class Query(graphene.ObjectType):
121111 node = relay.Node.Field()
122112 # Allows sorting over multiple columns, by default over the primary key
123 all_employees = SQLAlchemyConnectionField(EmployeeConnection)
113 all_employees = SQLAlchemyConnectionField(Employee.connection)
124114 # Disable sorting over this field
125 all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None)
115 all_departments = SQLAlchemyConnectionField(Department.connection, sort=None)
126116
127117 schema = graphene.Schema(query=Query)
128118
88 ---------------
99
1010 First you'll need to get the source of the project. Do this by cloning the
11 whole Graphene repository:
11 whole Graphene-SQLAlchemy repository:
1212
1313 ```bash
1414 # Get the example project code
00 #!/usr/bin/env python
11
2 from database import db_session, init_db
23 from flask import Flask
4 from schema import schema
35
46 from flask_graphql import GraphQLView
5
6 from .database import db_session, init_db
7 from .schema import schema
87
98 app = Flask(__name__)
109 app.debug = True
1110
12 default_query = '''
11 example_query = """
1312 {
14 allEmployees {
13 allEmployees(sort: [NAME_ASC, ID_ASC]) {
1514 edges {
1615 node {
17 id,
18 name,
16 id
17 name
1918 department {
20 id,
19 id
2120 name
22 },
21 }
2322 role {
24 id,
23 id
2524 name
2625 }
2726 }
2827 }
2928 }
30 }'''.strip()
29 }
30 """
3131
3232
33 app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True))
33 app.add_url_rule(
34 "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True)
35 )
3436
3537
3638 @app.teardown_appcontext
3739 def shutdown_session(exception=None):
3840 db_session.remove()
3941
40 if __name__ == '__main__':
42
43 if __name__ == "__main__":
4144 init_db()
4245 app.run()
1313 # import all modules here that might define models so that
1414 # they will be registered properly on the metadata. Otherwise
1515 # you will have to import them first before calling init_db()
16 from .models import Department, Employee, Role
16 from models import Department, Employee, Role
1717 Base.metadata.drop_all(bind=engine)
1818 Base.metadata.create_all(bind=engine)
1919
0 from database import Base
01 from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
12 from sqlalchemy.orm import backref, relationship
2
3 from .database import Base
43
54
65 class Department(Base):
0 graphene[sqlalchemy]
1 SQLAlchemy==1.0.11
2 Flask==0.12.4
3 Flask-GraphQL==1.3.0
0 -e ../../
1 Flask-GraphQL
0 from models import Department as DepartmentModel
1 from models import Employee as EmployeeModel
2 from models import Role as RoleModel
3
04 import graphene
15 from graphene import relay
2 from graphene_sqlalchemy import (SQLAlchemyConnectionField,
3 SQLAlchemyObjectType, utils)
4
5 from .models import Department as DepartmentModel
6 from .models import Employee as EmployeeModel
7 from .models import Role as RoleModel
6 from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
87
98
109 class Department(SQLAlchemyObjectType):
2524 interfaces = (relay.Node, )
2625
2726
28 SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee',
29 lambda c, d: c.upper() + ('_ASC' if d else '_DESC'))
30
31
3227 class Query(graphene.ObjectType):
3328 node = relay.Node.Field()
3429 # Allow only single column sorting
3530 all_employees = SQLAlchemyConnectionField(
36 Employee,
37 sort=graphene.Argument(
38 SortEnumEmployee,
39 default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc())))
31 Employee.connection, sort=Employee.sort_argument())
4032 # Allows sorting over multiple columns, by default over the primary key
41 all_roles = SQLAlchemyConnectionField(Role)
33 all_roles = SQLAlchemyConnectionField(Role.connection)
4234 # Disable sorting over this field
43 all_departments = SQLAlchemyConnectionField(Department, sort=None)
35 all_departments = SQLAlchemyConnectionField(Department.connection, sort=None)
4436
4537
46 schema = graphene.Schema(query=Query, types=[Department, Employee, Role])
38 schema = graphene.Schema(query=Query)
1313 ---------------
1414
1515 First you'll need to get the source of the project. Do this by cloning the
16 whole Graphene repository:
16 whole Graphene-SQLAlchemy repository:
1717
1818 ```bash
1919 # Get the example project code
4545
4646 ```bash
4747 ./run.sh
48
4948 ```
5049
5150 Now head on over to postman and send POST request to:
0 from database import db_session, init_db
1 from schema import schema
2
03 from graphql_server import (HttpQueryError, default_format_error,
14 encode_execution_results, json_encode,
25 load_json_body, run_http_query)
3
4 from .database import db_session, init_db
5 from .schema import schema
66
77
88 class App():
1313 # import all modules here that might define models so that
1414 # they will be registered properly on the metadata. Otherwise
1515 # you will have to import them first before calling init_db()
16 from .models import Department, Employee, Role
16 from models import Department, Employee, Role
1717 Base.metadata.drop_all(bind=engine)
1818 Base.metadata.create_all(bind=engine)
1919
0 from database import Base
01 from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func
12 from sqlalchemy.orm import backref, relationship
2
3 from .database import Base
43
54
65 class Department(Base):
0 graphene[sqlalchemy]
1 SQLAlchemy==1.0.11
0 -e ../../
1 graphql-server-core
22 nameko
3 graphql-server-core
0 from models import Department as DepartmentModel
1 from models import Employee as EmployeeModel
2 from models import Role as RoleModel
3
04 import graphene
15 from graphene import relay
26 from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
37
4 from .models import Department as DepartmentModel
5 from .models import Employee as EmployeeModel
6 from .models import Role as RoleModel
7
88
99 class Department(SQLAlchemyObjectType):
10
1110 class Meta:
1211 model = DepartmentModel
13 interfaces = (relay.Node, )
12 interfaces = (relay.Node,)
1413
1514
1615 class Employee(SQLAlchemyObjectType):
17
1816 class Meta:
1917 model = EmployeeModel
20 interfaces = (relay.Node, )
18 interfaces = (relay.Node,)
2119
2220
2321 class Role(SQLAlchemyObjectType):
24
2522 class Meta:
2623 model = RoleModel
27 interfaces = (relay.Node, )
24 interfaces = (relay.Node,)
2825
2926
3027 class Query(graphene.ObjectType):
3128 node = relay.Node.Field()
32 all_employees = SQLAlchemyConnectionField(Employee)
33 all_roles = SQLAlchemyConnectionField(Role)
29 all_employees = SQLAlchemyConnectionField(Employee.connection)
30 all_roles = SQLAlchemyConnectionField(Role.connection)
3431 role = graphene.Field(Role)
3532
3633
37 schema = graphene.Schema(query=Query, types=[Department, Employee, Role])
34 schema = graphene.Schema(query=Query)
00 #!/usr/bin/env python
1 from app import App
12 from nameko.web.handlers import http
2
3 from .app import App
43
54
65 class DepartmentService:
11 from .fields import SQLAlchemyConnectionField
22 from .utils import get_query, get_session
33
4 __version__ = "2.1.2"
4 __version__ = "2.3.0"
55
66 __all__ = [
77 "__version__",
0 import sqlalchemy
1 from promise import dataloader, promise
2 from sqlalchemy.orm import Session, strategies
3 from sqlalchemy.orm.query import QueryContext
4
5
6 def get_batch_resolver(relationship_prop):
7
8 # Cache this across `batch_load_fn` calls
9 # This is so SQL string generation is cached under-the-hood via `bakery`
10 selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
11
12 class RelationshipLoader(dataloader.DataLoader):
13 cache = False
14
15 def batch_load_fn(self, parents): # pylint: disable=method-hidden
16 """
17 Batch loads the relationships of all the parents as one SQL statement.
18
19 There is no way to do this out-of-the-box with SQLAlchemy but
20 we can piggyback on some internal APIs of the `selectin`
21 eager loading strategy. It's a bit hacky but it's preferable
22 than re-implementing and maintainnig a big chunk of the `selectin`
23 loader logic ourselves.
24
25 The approach here is to build a regular query that
26 selects the parent and `selectin` load the relationship.
27 But instead of having the query emits 2 `SELECT` statements
28 when callling `all()`, we skip the first `SELECT` statement
29 and jump right before the `selectin` loader is called.
30 To accomplish this, we have to construct objects that are
31 normally built in the first part of the query in order
32 to call directly `SelectInLoader._load_for_path`.
33
34 TODO Move this logic to a util in the SQLAlchemy repo as per
35 SQLAlchemy's main maitainer suggestion.
36 See https://git.io/JewQ7
37 """
38 child_mapper = relationship_prop.mapper
39 parent_mapper = relationship_prop.parent
40 session = Session.object_session(parents[0])
41
42 # These issues are very unlikely to happen in practice...
43 for parent in parents:
44 # assert parent.__mapper__ is parent_mapper
45 # All instances must share the same session
46 assert session is Session.object_session(parent)
47 # The behavior of `selectin` is undefined if the parent is dirty
48 assert parent not in session.dirty
49
50 # Should the boolean be set to False? Does it matter for our purposes?
51 states = [(sqlalchemy.inspect(parent), True) for parent in parents]
52
53 # For our purposes, the query_context will only used to get the session
54 query_context = QueryContext(session.query(parent_mapper.entity))
55
56 selectin_loader._load_for_path(
57 query_context,
58 parent_mapper._path_registry,
59 states,
60 None,
61 child_mapper,
62 )
63
64 return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents])
65
66 loader = RelationshipLoader()
67
68 def resolve(root, info, **args):
69 return loader.load(root)
70
71 return resolve
0 from enum import EnumMeta
1
02 from singledispatch import singledispatch
13 from sqlalchemy import types
24 from sqlalchemy.dialects import postgresql
3 from sqlalchemy.orm import interfaces
5 from sqlalchemy.orm import interfaces, strategies
46
57 from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
68 String)
79 from graphene.types.json import JSONString
10
11 from .batching import get_batch_resolver
12 from .enums import enum_for_sa_enum
13 from .fields import (BatchSQLAlchemyConnectionField,
14 default_connection_field_factory)
15 from .registry import get_global_registry
16 from .resolvers import get_attr_resolver, get_custom_resolver
817
918 try:
1019 from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
1221 ChoiceType = JSONType = ScalarListType = TSVectorType = object
1322
1423
24 is_selectin_available = getattr(strategies, 'SelectInLoader', None)
25
26
1527 def get_column_doc(column):
1628 return getattr(column, "doc", None)
1729
2032 return bool(getattr(column, "nullable", True))
2133
2234
23 def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory):
24 direction = relationship.direction
25 model = relationship.mapper.entity
26
35 def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
36 orm_field_name, **field_kwargs):
37 """
38 :param sqlalchemy.RelationshipProperty relationship_prop:
39 :param SQLAlchemyObjectType obj_type:
40 :param function|None connection_field_factory:
41 :param bool batching:
42 :param str orm_field_name:
43 :param dict field_kwargs:
44 :rtype: Dynamic
45 """
2746 def dynamic_type():
28 _type = registry.get_type_for_model(model)
29 if not _type:
47 """:rtype: Field|None"""
48 direction = relationship_prop.direction
49 child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
50 batching_ = batching if is_selectin_available else False
51
52 if not child_type:
3053 return None
31 if direction == interfaces.MANYTOONE or not relationship.uselist:
32 return Field(_type)
33 elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
34 if _type._meta.connection:
35 return connection_field_factory(relationship, registry)
36 return Field(List(_type))
54
55 if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
56 return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name,
57 **field_kwargs)
58
59 if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
60 return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_,
61 connection_field_factory, **field_kwargs)
3762
3863 return Dynamic(dynamic_type)
3964
4065
41 def convert_sqlalchemy_hybrid_method(hybrid_item):
42 return String(description=getattr(hybrid_item, "__doc__", None), required=False)
43
44
45 def convert_sqlalchemy_composite(composite, registry):
46 converter = registry.get_converter_for_composite(composite.composite_class)
66 def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs):
67 """
68 Convert one-to-one or many-to-one relationshsip. Return an object field.
69
70 :param sqlalchemy.RelationshipProperty relationship_prop:
71 :param SQLAlchemyObjectType obj_type:
72 :param bool batching:
73 :param str orm_field_name:
74 :param dict field_kwargs:
75 :rtype: Field
76 """
77 child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
78
79 resolver = get_custom_resolver(obj_type, orm_field_name)
80 if resolver is None:
81 resolver = get_batch_resolver(relationship_prop) if batching else \
82 get_attr_resolver(obj_type, relationship_prop.key)
83
84 return Field(child_type, resolver=resolver, **field_kwargs)
85
86
87 def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs):
88 """
89 Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field.
90
91 :param sqlalchemy.RelationshipProperty relationship_prop:
92 :param SQLAlchemyObjectType obj_type:
93 :param bool batching:
94 :param function|None connection_field_factory:
95 :param dict field_kwargs:
96 :rtype: Field
97 """
98 child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity)
99
100 if not child_type._meta.connection:
101 return Field(List(child_type), **field_kwargs)
102
103 # TODO Allow override of connection_field_factory and resolver via ORMField
104 if connection_field_factory is None:
105 connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \
106 default_connection_field_factory
107
108 return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs)
109
110
111 def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
112 if 'type' not in field_kwargs:
113 # TODO The default type should be dependent on the type of the property propety.
114 field_kwargs['type'] = String
115
116 return Field(
117 resolver=resolver,
118 **field_kwargs
119 )
120
121
122 def convert_sqlalchemy_composite(composite_prop, registry, resolver):
123 converter = registry.get_converter_for_composite(composite_prop.composite_class)
47124 if not converter:
48125 try:
49126 raise Exception(
50127 "Don't know how to convert the composite field %s (%s)"
51 % (composite, composite.composite_class)
128 % (composite_prop, composite_prop.composite_class)
52129 )
53130 except AttributeError:
54131 # handle fields that are not attached to a class yet (don't have a parent)
55132 raise Exception(
56133 "Don't know how to convert the composite field %r (%s)"
57 % (composite, composite.composite_class)
134 % (composite_prop, composite_prop.composite_class)
58135 )
59 return converter(composite, registry)
136
137 # TODO Add a way to override composite fields default parameters
138 return converter(composite_prop, registry)
60139
61140
62141 def _register_composite_class(cls, registry=None):
74153 convert_sqlalchemy_composite.register = _register_composite_class
75154
76155
77 def convert_sqlalchemy_column(column, registry=None):
78 return convert_sqlalchemy_type(getattr(column, "type", None), column, registry)
156 def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
157 column = column_prop.columns[0]
158 field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
159 field_kwargs.setdefault('required', not is_column_nullable(column))
160 field_kwargs.setdefault('description', get_column_doc(column))
161
162 return Field(
163 resolver=resolver,
164 **field_kwargs
165 )
79166
80167
81168 @singledispatch
97184 @convert_sqlalchemy_type.register(postgresql.CIDR)
98185 @convert_sqlalchemy_type.register(TSVectorType)
99186 def convert_column_to_string(type, column, registry=None):
100 return String(
101 description=get_column_doc(column), required=not (is_column_nullable(column))
102 )
187 return String
103188
104189
105190 @convert_sqlalchemy_type.register(types.DateTime)
106191 def convert_column_to_datetime(type, column, registry=None):
107192 from graphene.types.datetime import DateTime
108
109 return DateTime(
110 description=get_column_doc(column), required=not (is_column_nullable(column))
111 )
193 return DateTime
112194
113195
114196 @convert_sqlalchemy_type.register(types.SmallInteger)
115197 @convert_sqlalchemy_type.register(types.Integer)
116198 def convert_column_to_int_or_id(type, column, registry=None):
117 if column.primary_key:
118 return ID(
119 description=get_column_doc(column),
120 required=not (is_column_nullable(column)),
121 )
122 else:
123 return Int(
124 description=get_column_doc(column),
125 required=not (is_column_nullable(column)),
126 )
199 return ID if column.primary_key else Int
127200
128201
129202 @convert_sqlalchemy_type.register(types.Boolean)
130203 def convert_column_to_boolean(type, column, registry=None):
131 return Boolean(
132 description=get_column_doc(column), required=not (is_column_nullable(column))
133 )
204 return Boolean
134205
135206
136207 @convert_sqlalchemy_type.register(types.Float)
137208 @convert_sqlalchemy_type.register(types.Numeric)
138209 @convert_sqlalchemy_type.register(types.BigInteger)
139210 def convert_column_to_float(type, column, registry=None):
140 return Float(
141 description=get_column_doc(column), required=not (is_column_nullable(column))
142 )
211 return Float
143212
144213
145214 @convert_sqlalchemy_type.register(types.Enum)
146215 def convert_enum_to_enum(type, column, registry=None):
147 enum_class = getattr(type, 'enum_class', None)
148 if enum_class: # Check if an enum.Enum type is used
149 graphene_type = Enum.from_enum(enum_class)
150 else: # Nope, just a list of string options
151 items = zip(type.enums, type.enums)
152 graphene_type = Enum(type.name, items)
153 return Field(
154 graphene_type,
155 description=get_column_doc(column),
156 required=not (is_column_nullable(column)),
157 )
158
159
216 return lambda: enum_for_sa_enum(type, registry or get_global_registry())
217
218
219 # TODO Make ChoiceType conversion consistent with other enums
160220 @convert_sqlalchemy_type.register(ChoiceType)
161 def convert_column_to_enum(type, column, registry=None):
221 def convert_choice_to_enum(type, column, registry=None):
162222 name = "{}_{}".format(column.table.name, column.name).upper()
163 return Enum(name, type.choices, description=get_column_doc(column))
223 if isinstance(type.choices, EnumMeta):
224 # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta
225 # do not use from_enum here because we can have more than one enum column in table
226 return Enum(name, list((v.name, v.value) for v in type.choices))
227 else:
228 return Enum(name, type.choices)
164229
165230
166231 @convert_sqlalchemy_type.register(ScalarListType)
167232 def convert_scalar_list_to_list(type, column, registry=None):
168 return List(String, description=get_column_doc(column))
169
170
233 return List(String)
234
235
236 @convert_sqlalchemy_type.register(types.ARRAY)
171237 @convert_sqlalchemy_type.register(postgresql.ARRAY)
172 def convert_postgres_array_to_list(_type, column, registry=None):
173 graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
174 inner_type = type(graphene_type)
175 return List(
176 inner_type,
177 description=get_column_doc(column),
178 required=not (is_column_nullable(column)),
179 )
238 def convert_array_to_list(_type, column, registry=None):
239 inner_type = convert_sqlalchemy_type(column.type.item_type, column)
240 return List(inner_type)
180241
181242
182243 @convert_sqlalchemy_type.register(postgresql.HSTORE)
183244 @convert_sqlalchemy_type.register(postgresql.JSON)
184245 @convert_sqlalchemy_type.register(postgresql.JSONB)
185246 def convert_json_to_string(type, column, registry=None):
186 return JSONString(
187 description=get_column_doc(column), required=not (is_column_nullable(column))
188 )
247 return JSONString
189248
190249
191250 @convert_sqlalchemy_type.register(JSONType)
192251 def convert_json_type_to_string(type, column, registry=None):
193 return JSONString(
194 description=get_column_doc(column), required=not (is_column_nullable(column))
195 )
252 return JSONString
0 import six
1 from sqlalchemy.orm import ColumnProperty
2 from sqlalchemy.types import Enum as SQLAlchemyEnumType
3
4 from graphene import Argument, Enum, List
5
6 from .utils import EnumValue, to_enum_value_name, to_type_name
7
8
9 def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None):
10 """Convert the given SQLAlchemy Enum type to a Graphene Enum type.
11
12 The name of the Graphene Enum will be determined as follows:
13 If the SQLAlchemy Enum is based on a Python Enum, use the name
14 of the Python Enum. Otherwise, if the SQLAlchemy Enum is named,
15 use the SQL name after conversion to a type name. Otherwise, use
16 the given fallback_name or raise an error if it is empty.
17
18 The Enum value names are converted to upper case if necessary.
19 """
20 if not isinstance(sa_enum, SQLAlchemyEnumType):
21 raise TypeError(
22 "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
23 )
24 enum_class = sa_enum.enum_class
25 if enum_class:
26 if all(to_enum_value_name(key) == key for key in enum_class.__members__):
27 return Enum.from_enum(enum_class)
28 name = enum_class.__name__
29 members = [
30 (to_enum_value_name(key), value.value)
31 for key, value in enum_class.__members__.items()
32 ]
33 else:
34 sql_enum_name = sa_enum.name
35 if sql_enum_name:
36 name = to_type_name(sql_enum_name)
37 elif fallback_name:
38 name = fallback_name
39 else:
40 raise TypeError("No type name specified for {!r}".format(sa_enum))
41 members = [(to_enum_value_name(key), key) for key in sa_enum.enums]
42 return Enum(name, members)
43
44
45 def enum_for_sa_enum(sa_enum, registry):
46 """Return the Graphene Enum type for the specified SQLAlchemy Enum type."""
47 if not isinstance(sa_enum, SQLAlchemyEnumType):
48 raise TypeError(
49 "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum)
50 )
51 enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
52 if not enum:
53 enum = _convert_sa_to_graphene_enum(sa_enum)
54 registry.register_enum(sa_enum, enum)
55 return enum
56
57
58 def enum_for_field(obj_type, field_name):
59 """Return the Graphene Enum type for the specified Graphene field."""
60 from .types import SQLAlchemyObjectType
61
62 if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType):
63 raise TypeError(
64 "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type))
65 if not field_name or not isinstance(field_name, six.string_types):
66 raise TypeError(
67 "Expected a field name, but got: {!r}".format(field_name))
68 registry = obj_type._meta.registry
69 orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
70 if orm_field is None:
71 raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name))
72 if not isinstance(orm_field, ColumnProperty):
73 raise TypeError(
74 "{}.{} does not map to model column".format(obj_type._meta.name, field_name)
75 )
76 column = orm_field.columns[0]
77 sa_enum = column.type
78 if not isinstance(sa_enum, SQLAlchemyEnumType):
79 raise TypeError(
80 "{}.{} does not map to enum column".format(obj_type._meta.name, field_name)
81 )
82 enum = registry.get_graphene_enum_for_sa_enum(sa_enum)
83 if not enum:
84 fallback_name = obj_type._meta.name + to_type_name(field_name)
85 enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name)
86 registry.register_enum(sa_enum, enum)
87 return enum
88
89
90 def _default_sort_enum_symbol_name(column_name, sort_asc=True):
91 return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC")
92
93
94 def sort_enum_for_object_type(
95 obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None
96 ):
97 """Return Graphene Enum for sorting the given SQLAlchemyObjectType.
98
99 Parameters
100 - obj_type : SQLAlchemyObjectType
101 The object type for which the sort Enum shall be generated.
102 - name : str, optional, default None
103 Name to use for the sort Enum.
104 If not provided, it will be set to the object type name + 'SortEnum'
105 - only_fields : sequence, optional, default None
106 If this is set, only fields from this sequence will be considered.
107 - only_indexed : bool, optional, default False
108 If this is set, only indexed columns will be considered.
109 - get_symbol_name : function, optional, default None
110 Function which takes the column name and a boolean indicating
111 if the sort direction is ascending, and returns the symbol name
112 for the current column and sort direction. If no such function
113 is passed, a default function will be used that creates the symbols
114 'foo_asc' and 'foo_desc' for a column with the name 'foo'.
115
116 Returns
117 - Enum
118 The Graphene Enum type
119 """
120 name = name or obj_type._meta.name + "SortEnum"
121 registry = obj_type._meta.registry
122 enum = registry.get_sort_enum_for_object_type(obj_type)
123 custom_options = dict(
124 only_fields=only_fields,
125 only_indexed=only_indexed,
126 get_symbol_name=get_symbol_name,
127 )
128 if enum:
129 if name != enum.__name__ or custom_options != enum.custom_options:
130 raise ValueError(
131 "Sort enum for {} has already been customized".format(obj_type)
132 )
133 else:
134 members = []
135 default = []
136 fields = obj_type._meta.fields
137 get_name = get_symbol_name or _default_sort_enum_symbol_name
138 for field_name in fields:
139 if only_fields and field_name not in only_fields:
140 continue
141 orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name)
142 if not isinstance(orm_field, ColumnProperty):
143 continue
144 column = orm_field.columns[0]
145 if only_indexed and not (column.primary_key or column.index):
146 continue
147 asc_name = get_name(column.name, True)
148 asc_value = EnumValue(asc_name, column.asc())
149 desc_name = get_name(column.name, False)
150 desc_value = EnumValue(desc_name, column.desc())
151 if column.primary_key:
152 default.append(asc_value)
153 members.extend(((asc_name, asc_value), (desc_name, desc_value)))
154 enum = Enum(name, members)
155 enum.default = default # store default as attribute
156 enum.custom_options = custom_options
157 registry.register_sort_enum(obj_type, enum)
158 return enum
159
160
161 def sort_argument_for_object_type(
162 obj_type,
163 enum_name=None,
164 only_fields=None,
165 only_indexed=None,
166 get_symbol_name=None,
167 has_default=True,
168 ):
169 """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType.
170
171 Parameters
172 - obj_type : SQLAlchemyObjectType
173 The object type for which the sort Argument shall be generated.
174 - enum_name : str, optional, default None
175 Name to use for the sort Enum.
176 If not provided, it will be set to the object type name + 'SortEnum'
177 - only_fields : sequence, optional, default None
178 If this is set, only fields from this sequence will be considered.
179 - only_indexed : bool, optional, default False
180 If this is set, only indexed columns will be considered.
181 - get_symbol_name : function, optional, default None
182 Function which takes the column name and a boolean indicating
183 if the sort direction is ascending, and returns the symbol name
184 for the current column and sort direction. If no such function
185 is passed, a default function will be used that creates the symbols
186 'foo_asc' and 'foo_desc' for a column with the name 'foo'.
187 - has_default : bool, optional, default True
188 If this is set to False, no sorting will happen when this argument is not
189 passed. Otherwise results will be sortied by the primary key(s) of the model.
190
191 Returns
192 - Enum
193 A Graphene Argument that accepts a list of sorting directions for the model.
194 """
195 enum = sort_enum_for_object_type(
196 obj_type,
197 enum_name,
198 only_fields=only_fields,
199 only_indexed=only_indexed,
200 get_symbol_name=get_symbol_name,
201 )
202 if not has_default:
203 enum.default = None
204
205 return Argument(List(enum), default_value=enum.default)
0 import logging
0 import warnings
11 from functools import partial
22
3 import six
34 from promise import Promise, is_thenable
45 from sqlalchemy.orm.query import Query
56
7 from graphene import NonNull
68 from graphene.relay import Connection, ConnectionField
79 from graphene.relay.connection import PageInfo
810 from graphql_relay.connection.arrayconnection import connection_from_list_slice
911
10 from .utils import get_query, sort_argument_for_model
11
12 log = logging.getLogger()
12 from .batching import get_batch_resolver
13 from .utils import get_query
1314
1415
1516 class UnsortedSQLAlchemyConnectionField(ConnectionField):
1819 from .types import SQLAlchemyObjectType
1920
2021 _type = super(ConnectionField, self).type
21 if issubclass(_type, Connection):
22 nullable_type = get_nullable_type(_type)
23 if issubclass(nullable_type, Connection):
2224 return _type
23 assert issubclass(_type, SQLAlchemyObjectType), (
25 assert issubclass(nullable_type, SQLAlchemyObjectType), (
2426 "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}"
25 ).format(_type.__name__)
26 assert _type._meta.connection, "The type {} doesn't have a connection".format(
27 _type.__name__
27 ).format(nullable_type.__name__)
28 assert (
29 nullable_type.connection
30 ), "The type {} doesn't have a connection".format(
31 nullable_type.__name__
2832 )
29 return _type._meta.connection
33 assert _type == nullable_type, (
34 "Passing a SQLAlchemyObjectType instance is deprecated. "
35 "Pass the connection type instead accessible via SQLAlchemyObjectType.connection"
36 )
37 return nullable_type.connection
3038
3139 @property
3240 def model(self):
33 return self.type._meta.node._meta.model
41 return get_nullable_type(self.type)._meta.node._meta.model
3442
3543 @classmethod
36 def get_query(cls, model, info, sort=None, **args):
37 query = get_query(model, info.context)
38 if sort is not None:
39 if isinstance(sort, str):
40 query = query.order_by(sort.value)
41 else:
42 query = query.order_by(*(col.value for col in sort))
43 return query
44 def get_query(cls, model, info, **args):
45 return get_query(model, info.context)
4446
4547 @classmethod
4648 def resolve_connection(cls, connection_type, model, info, args, resolved):
7577 return on_resolve(resolved)
7678
7779 def get_resolver(self, parent_resolver):
78 return partial(self.connection_resolver, parent_resolver, self.type, self.model)
80 return partial(
81 self.connection_resolver,
82 parent_resolver,
83 get_nullable_type(self.type),
84 self.model,
85 )
7986
8087
88 # TODO Rename this to SortableSQLAlchemyConnectionField
8189 class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
8290 def __init__(self, type, *args, **kwargs):
83 if "sort" not in kwargs and issubclass(type, Connection):
91 nullable_type = get_nullable_type(type)
92 if "sort" not in kwargs and issubclass(nullable_type, Connection):
8493 # Let super class raise if type is not a Connection
8594 try:
86 model = type.Edge.node._type._meta.model
87 kwargs.setdefault("sort", sort_argument_for_model(model))
88 except Exception:
89 raise Exception(
95 kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument())
96 except (AttributeError, TypeError):
97 raise TypeError(
9098 'Cannot create sort argument for {}. A model is required. Set the "sort" argument'
9199 " to None to disabling the creation of the sort query argument".format(
92 type.__name__
100 nullable_type.__name__
93101 )
94102 )
95103 elif "sort" in kwargs and kwargs["sort"] is None:
96104 del kwargs["sort"]
97105 super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)
98106
107 @classmethod
108 def get_query(cls, model, info, sort=None, **args):
109 query = get_query(model, info.context)
110 if sort is not None:
111 if isinstance(sort, six.string_types):
112 query = query.order_by(sort.value)
113 else:
114 query = query.order_by(*(col.value for col in sort))
115 return query
99116
100 def default_connection_field_factory(relationship, registry):
117
118 class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
119 """
120 This is currently experimental.
121 The API and behavior may change in future versions.
122 Use at your own risk.
123 """
124
125 def get_resolver(self, parent_resolver):
126 return partial(
127 self.connection_resolver,
128 self.resolver,
129 get_nullable_type(self.type),
130 self.model,
131 )
132
133 @classmethod
134 def from_relationship(cls, relationship, registry, **field_kwargs):
135 model = relationship.mapper.entity
136 model_type = registry.get_type_for_model(model)
137 return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs)
138
139
140 def default_connection_field_factory(relationship, registry, **field_kwargs):
101141 model = relationship.mapper.entity
102142 model_type = registry.get_type_for_model(model)
103 return createConnectionField(model_type)
143 return __connectionFactory(model_type, **field_kwargs)
104144
105145
106146 # TODO Remove in next major version
107147 __connectionFactory = UnsortedSQLAlchemyConnectionField
108148
109149
110 def createConnectionField(_type):
111 log.warn(
150 def createConnectionField(_type, **field_kwargs):
151 warnings.warn(
112152 'createConnectionField is deprecated and will be removed in the next '
113 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
153 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
154 DeprecationWarning,
114155 )
115 return __connectionFactory(_type)
156 return __connectionFactory(_type, **field_kwargs)
116157
117158
118159 def registerConnectionFieldFactory(factoryMethod):
119 log.warn(
160 warnings.warn(
120161 'registerConnectionFieldFactory is deprecated and will be removed in the next '
121 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
162 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
163 DeprecationWarning,
122164 )
123165 global __connectionFactory
124166 __connectionFactory = factoryMethod
125167
126168
127169 def unregisterConnectionFieldFactory():
128 log.warn(
170 warnings.warn(
129171 'registerConnectionFieldFactory is deprecated and will be removed in the next '
130 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
172 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.',
173 DeprecationWarning,
131174 )
132175 global __connectionFactory
133176 __connectionFactory = UnsortedSQLAlchemyConnectionField
177
178
179 def get_nullable_type(_type):
180 if isinstance(_type, NonNull):
181 return _type.of_type
182 return _type
0 from collections import defaultdict
1
2 import six
3 from sqlalchemy.types import Enum as SQLAlchemyEnumType
4
5 from graphene import Enum
6
7
08 class Registry(object):
19 def __init__(self):
210 self._registry = {}
311 self._registry_models = {}
12 self._registry_orm_fields = defaultdict(dict)
413 self._registry_composites = {}
14 self._registry_enums = {}
15 self._registry_sort_enums = {}
516
6 def register(self, cls):
17 def register(self, obj_type):
718 from .types import SQLAlchemyObjectType
819
9 assert issubclass(cls, SQLAlchemyObjectType), (
10 "Only classes of type SQLAlchemyObjectType can be registered, "
11 'received "{}"'
12 ).format(cls.__name__)
13 assert cls._meta.registry == self, "Registry for a Model have to match."
20 if not isinstance(obj_type, type) or not issubclass(
21 obj_type, SQLAlchemyObjectType
22 ):
23 raise TypeError(
24 "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
25 )
26 assert obj_type._meta.registry == self, "Registry for a Model have to match."
1427 # assert self.get_type_for_model(cls._meta.model) in [None, cls], (
1528 # 'SQLAlchemy model "{}" already associated with '
1629 # 'another type "{}".'
1730 # ).format(cls._meta.model, self._registry[cls._meta.model])
18 self._registry[cls._meta.model] = cls
31 self._registry[obj_type._meta.model] = obj_type
1932
2033 def get_type_for_model(self, model):
2134 return self._registry.get(model)
35
36 def register_orm_field(self, obj_type, field_name, orm_field):
37 from .types import SQLAlchemyObjectType
38
39 if not isinstance(obj_type, type) or not issubclass(
40 obj_type, SQLAlchemyObjectType
41 ):
42 raise TypeError(
43 "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
44 )
45 if not field_name or not isinstance(field_name, six.string_types):
46 raise TypeError("Expected a field name, but got: {!r}".format(field_name))
47 self._registry_orm_fields[obj_type][field_name] = orm_field
48
49 def get_orm_field_for_graphene_field(self, obj_type, field_name):
50 return self._registry_orm_fields.get(obj_type, {}).get(field_name)
2251
2352 def register_composite_converter(self, composite, converter):
2453 self._registry_composites[composite] = converter
2554
2655 def get_converter_for_composite(self, composite):
2756 return self._registry_composites.get(composite)
57
58 def register_enum(self, sa_enum, graphene_enum):
59 if not isinstance(sa_enum, SQLAlchemyEnumType):
60 raise TypeError(
61 "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum)
62 )
63 if not isinstance(graphene_enum, type(Enum)):
64 raise TypeError(
65 "Expected Graphene Enum, but got: {!r}".format(graphene_enum)
66 )
67
68 self._registry_enums[sa_enum] = graphene_enum
69
70 def get_graphene_enum_for_sa_enum(self, sa_enum):
71 return self._registry_enums.get(sa_enum)
72
73 def register_sort_enum(self, obj_type, sort_enum):
74 from .types import SQLAlchemyObjectType
75
76 if not isinstance(obj_type, type) or not issubclass(
77 obj_type, SQLAlchemyObjectType
78 ):
79 raise TypeError(
80 "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
81 )
82 if not isinstance(sort_enum, type(Enum)):
83 raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum))
84 self._registry_sort_enums[obj_type] = sort_enum
85
86 def get_sort_enum_for_object_type(self, obj_type):
87 return self._registry_sort_enums.get(obj_type)
2888
2989
3090 registry = None
0 from graphene.utils.get_unbound_function import get_unbound_function
1
2
3 def get_custom_resolver(obj_type, orm_field_name):
4 """
5 Since `graphene` will call `resolve_<field_name>` on a field only if it
6 does not have a `resolver`, we need to re-implement that logic here so
7 users are able to override the default resolvers that we provide.
8 """
9 resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
10 if resolver:
11 return get_unbound_function(resolver)
12
13 return None
14
15
16 def get_attr_resolver(obj_type, model_attr):
17 """
18 In order to support field renaming via `ORMField.model_attr`,
19 we need to define resolver functions for each field.
20
21 :param SQLAlchemyObjectType obj_type:
22 :param str model_attr: the name of the SQLAlchemy attribute
23 :rtype: Callable
24 """
25 return lambda root, _info: getattr(root, model_attr, None)
0 import pytest
1 from sqlalchemy import create_engine
2 from sqlalchemy.orm import sessionmaker
3
4 import graphene
5
6 from ..converter import convert_sqlalchemy_composite
7 from ..registry import reset_global_registry
8 from .models import Base, CompositeFullName
9
10 test_db_url = 'sqlite://' # use in-memory database for tests
11
12
13 @pytest.fixture(autouse=True)
14 def reset_registry():
15 reset_global_registry()
16
17 # Prevent tests that implicitly depend on Reporter from raising
18 # Tests that explicitly depend on this behavior should re-register a converter
19 @convert_sqlalchemy_composite.register(CompositeFullName)
20 def convert_composite_class(composite, registry):
21 return graphene.Field(graphene.Int)
22
23
24 @pytest.yield_fixture(scope="function")
25 def session_factory():
26 engine = create_engine(test_db_url)
27 Base.metadata.create_all(engine)
28
29 yield sessionmaker(bind=engine)
30
31 # SQLite in-memory db is deleted when its connection is closed.
32 # https://www.sqlite.org/inmemorydb.html
33 engine.dispose()
34
35
36 @pytest.fixture(scope="function")
37 def session(session_factory):
38 return session_factory()
11
22 import enum
33
4 from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table
4 from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table,
5 func, select)
56 from sqlalchemy.ext.declarative import declarative_base
6 from sqlalchemy.orm import mapper, relationship
7 from sqlalchemy.ext.hybrid import hybrid_property
8 from sqlalchemy.orm import column_property, composite, mapper, relationship
9
10 PetKind = Enum("cat", "dog", name="pet_kind")
711
812
9 class Hairkind(enum.Enum):
13 class HairKind(enum.Enum):
1014 LONG = 'long'
1115 SHORT = 'short'
1216
3135 __tablename__ = "pets"
3236 id = Column(Integer(), primary_key=True)
3337 name = Column(String(30))
34 pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False)
35 hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False)
38 pet_kind = Column(PetKind, nullable=False)
39 hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False)
3640 reporter_id = Column(Integer(), ForeignKey("reporters.id"))
41
42
43 class CompositeFullName(object):
44 def __init__(self, first_name, last_name):
45 self.first_name = first_name
46 self.last_name = last_name
47
48 def __composite_values__(self):
49 return self.first_name, self.last_name
50
51 def __repr__(self):
52 return "{} {}".format(self.first_name, self.last_name)
3753
3854
3955 class Reporter(Base):
4056 __tablename__ = "reporters"
57
4158 id = Column(Integer(), primary_key=True)
42 first_name = Column(String(30))
43 last_name = Column(String(30))
44 email = Column(String())
45 pets = relationship("Pet", secondary=association_table, backref="reporters")
59 first_name = Column(String(30), doc="First name")
60 last_name = Column(String(30), doc="Last name")
61 email = Column(String(), doc="Email")
62 favorite_pet_kind = Column(PetKind)
63 pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id")
4664 articles = relationship("Article", backref="reporter")
4765 favorite_article = relationship("Article", uselist=False)
4866
49 # total = column_property(
50 # select([
51 # func.cast(func.count(PersonInfo.id), Float)
52 # ])
53 # )
67 @hybrid_property
68 def hybrid_prop(self):
69 return self.first_name
70
71 column_prop = column_property(
72 select([func.cast(func.count(id), Integer)]), doc="Column property"
73 )
74
75 composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite")
5476
5577
5678 class Article(Base):
0 import contextlib
1 import logging
2
3 import pytest
4
5 import graphene
6 from graphene import relay
7
8 from ..fields import (BatchSQLAlchemyConnectionField,
9 default_connection_field_factory)
10 from ..types import ORMField, SQLAlchemyObjectType
11 from .models import Article, HairKind, Pet, Reporter
12 from .utils import is_sqlalchemy_version_less_than, to_std_dicts
13
14
15 class MockLoggingHandler(logging.Handler):
16 """Intercept and store log messages in a list."""
17 def __init__(self, *args, **kwargs):
18 self.messages = []
19 logging.Handler.__init__(self, *args, **kwargs)
20
21 def emit(self, record):
22 self.messages.append(record.getMessage())
23
24
25 @contextlib.contextmanager
26 def mock_sqlalchemy_logging_handler():
27 logging.basicConfig()
28 sql_logger = logging.getLogger('sqlalchemy.engine')
29 previous_level = sql_logger.level
30
31 sql_logger.setLevel(logging.INFO)
32 mock_logging_handler = MockLoggingHandler()
33 mock_logging_handler.setLevel(logging.INFO)
34 sql_logger.addHandler(mock_logging_handler)
35
36 yield mock_logging_handler
37
38 sql_logger.setLevel(previous_level)
39
40
41 def get_schema():
42 class ReporterType(SQLAlchemyObjectType):
43 class Meta:
44 model = Reporter
45 interfaces = (relay.Node,)
46 batching = True
47
48 class ArticleType(SQLAlchemyObjectType):
49 class Meta:
50 model = Article
51 interfaces = (relay.Node,)
52 batching = True
53
54 class PetType(SQLAlchemyObjectType):
55 class Meta:
56 model = Pet
57 interfaces = (relay.Node,)
58 batching = True
59
60 class Query(graphene.ObjectType):
61 articles = graphene.Field(graphene.List(ArticleType))
62 reporters = graphene.Field(graphene.List(ReporterType))
63
64 def resolve_articles(self, info):
65 return info.context.get('session').query(Article).all()
66
67 def resolve_reporters(self, info):
68 return info.context.get('session').query(Reporter).all()
69
70 return graphene.Schema(query=Query)
71
72
73 if is_sqlalchemy_version_less_than('1.2'):
74 pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True)
75
76
77 def test_many_to_one(session_factory):
78 session = session_factory()
79
80 reporter_1 = Reporter(
81 first_name='Reporter_1',
82 )
83 session.add(reporter_1)
84 reporter_2 = Reporter(
85 first_name='Reporter_2',
86 )
87 session.add(reporter_2)
88
89 article_1 = Article(headline='Article_1')
90 article_1.reporter = reporter_1
91 session.add(article_1)
92
93 article_2 = Article(headline='Article_2')
94 article_2.reporter = reporter_2
95 session.add(article_2)
96
97 session.commit()
98 session.close()
99
100 schema = get_schema()
101
102 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
103 # Starts new session to fully reset the engine / connection logging level
104 session = session_factory()
105 result = schema.execute("""
106 query {
107 articles {
108 headline
109 reporter {
110 firstName
111 }
112 }
113 }
114 """, context_value={"session": session})
115 messages = sqlalchemy_logging_handler.messages
116
117 assert len(messages) == 5
118
119 if is_sqlalchemy_version_less_than('1.3'):
120 # The batched SQL statement generated is different in 1.2.x
121 # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
122 # See https://git.io/JewQu
123 sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message]
124 assert len(sql_statements) == 1
125 return
126
127 assert messages == [
128 'BEGIN (implicit)',
129
130 'SELECT articles.id AS articles_id, '
131 'articles.headline AS articles_headline, '
132 'articles.pub_date AS articles_pub_date, '
133 'articles.reporter_id AS articles_reporter_id \n'
134 'FROM articles',
135 '()',
136
137 'SELECT reporters.id AS reporters_id, '
138 '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
139 'reporters.first_name AS reporters_first_name, '
140 'reporters.last_name AS reporters_last_name, '
141 'reporters.email AS reporters_email, '
142 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
143 'FROM reporters \n'
144 'WHERE reporters.id IN (?, ?)',
145 '(1, 2)',
146 ]
147
148 assert not result.errors
149 result = to_std_dicts(result.data)
150 assert result == {
151 "articles": [
152 {
153 "headline": "Article_1",
154 "reporter": {
155 "firstName": "Reporter_1",
156 },
157 },
158 {
159 "headline": "Article_2",
160 "reporter": {
161 "firstName": "Reporter_2",
162 },
163 },
164 ],
165 }
166
167
168 def test_one_to_one(session_factory):
169 session = session_factory()
170
171 reporter_1 = Reporter(
172 first_name='Reporter_1',
173 )
174 session.add(reporter_1)
175 reporter_2 = Reporter(
176 first_name='Reporter_2',
177 )
178 session.add(reporter_2)
179
180 article_1 = Article(headline='Article_1')
181 article_1.reporter = reporter_1
182 session.add(article_1)
183
184 article_2 = Article(headline='Article_2')
185 article_2.reporter = reporter_2
186 session.add(article_2)
187
188 session.commit()
189 session.close()
190
191 schema = get_schema()
192
193 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
194 # Starts new session to fully reset the engine / connection logging level
195 session = session_factory()
196 result = schema.execute("""
197 query {
198 reporters {
199 firstName
200 favoriteArticle {
201 headline
202 }
203 }
204 }
205 """, context_value={"session": session})
206 messages = sqlalchemy_logging_handler.messages
207
208 assert len(messages) == 5
209
210 if is_sqlalchemy_version_less_than('1.3'):
211 # The batched SQL statement generated is different in 1.2.x
212 # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
213 # See https://git.io/JewQu
214 sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message]
215 assert len(sql_statements) == 1
216 return
217
218 assert messages == [
219 'BEGIN (implicit)',
220
221 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
222 'reporters.id AS reporters_id, '
223 'reporters.first_name AS reporters_first_name, '
224 'reporters.last_name AS reporters_last_name, '
225 'reporters.email AS reporters_email, '
226 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
227 'FROM reporters',
228 '()',
229
230 'SELECT articles.reporter_id AS articles_reporter_id, '
231 'articles.id AS articles_id, '
232 'articles.headline AS articles_headline, '
233 'articles.pub_date AS articles_pub_date \n'
234 'FROM articles \n'
235 'WHERE articles.reporter_id IN (?, ?)',
236 '(1, 2)'
237 ]
238
239 assert not result.errors
240 result = to_std_dicts(result.data)
241 assert result == {
242 "reporters": [
243 {
244 "firstName": "Reporter_1",
245 "favoriteArticle": {
246 "headline": "Article_1",
247 },
248 },
249 {
250 "firstName": "Reporter_2",
251 "favoriteArticle": {
252 "headline": "Article_2",
253 },
254 },
255 ],
256 }
257
258
259 def test_one_to_many(session_factory):
260 session = session_factory()
261
262 reporter_1 = Reporter(
263 first_name='Reporter_1',
264 )
265 session.add(reporter_1)
266 reporter_2 = Reporter(
267 first_name='Reporter_2',
268 )
269 session.add(reporter_2)
270
271 article_1 = Article(headline='Article_1')
272 article_1.reporter = reporter_1
273 session.add(article_1)
274
275 article_2 = Article(headline='Article_2')
276 article_2.reporter = reporter_1
277 session.add(article_2)
278
279 article_3 = Article(headline='Article_3')
280 article_3.reporter = reporter_2
281 session.add(article_3)
282
283 article_4 = Article(headline='Article_4')
284 article_4.reporter = reporter_2
285 session.add(article_4)
286
287 session.commit()
288 session.close()
289
290 schema = get_schema()
291
292 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
293 # Starts new session to fully reset the engine / connection logging level
294 session = session_factory()
295 result = schema.execute("""
296 query {
297 reporters {
298 firstName
299 articles(first: 2) {
300 edges {
301 node {
302 headline
303 }
304 }
305 }
306 }
307 }
308 """, context_value={"session": session})
309 messages = sqlalchemy_logging_handler.messages
310
311 assert len(messages) == 5
312
313 if is_sqlalchemy_version_less_than('1.3'):
314 # The batched SQL statement generated is different in 1.2.x
315 # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
316 # See https://git.io/JewQu
317 sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message]
318 assert len(sql_statements) == 1
319 return
320
321 assert messages == [
322 'BEGIN (implicit)',
323
324 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
325 'reporters.id AS reporters_id, '
326 'reporters.first_name AS reporters_first_name, '
327 'reporters.last_name AS reporters_last_name, '
328 'reporters.email AS reporters_email, '
329 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
330 'FROM reporters',
331 '()',
332
333 'SELECT articles.reporter_id AS articles_reporter_id, '
334 'articles.id AS articles_id, '
335 'articles.headline AS articles_headline, '
336 'articles.pub_date AS articles_pub_date \n'
337 'FROM articles \n'
338 'WHERE articles.reporter_id IN (?, ?)',
339 '(1, 2)'
340 ]
341
342 assert not result.errors
343 result = to_std_dicts(result.data)
344 assert result == {
345 "reporters": [
346 {
347 "firstName": "Reporter_1",
348 "articles": {
349 "edges": [
350 {
351 "node": {
352 "headline": "Article_1",
353 },
354 },
355 {
356 "node": {
357 "headline": "Article_2",
358 },
359 },
360 ],
361 },
362 },
363 {
364 "firstName": "Reporter_2",
365 "articles": {
366 "edges": [
367 {
368 "node": {
369 "headline": "Article_3",
370 },
371 },
372 {
373 "node": {
374 "headline": "Article_4",
375 },
376 },
377 ],
378 },
379 },
380 ],
381 }
382
383
384 def test_many_to_many(session_factory):
385 session = session_factory()
386
387 reporter_1 = Reporter(
388 first_name='Reporter_1',
389 )
390 session.add(reporter_1)
391 reporter_2 = Reporter(
392 first_name='Reporter_2',
393 )
394 session.add(reporter_2)
395
396 pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG)
397 session.add(pet_1)
398
399 pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG)
400 session.add(pet_2)
401
402 reporter_1.pets.append(pet_1)
403 reporter_1.pets.append(pet_2)
404
405 pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG)
406 session.add(pet_3)
407
408 pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG)
409 session.add(pet_4)
410
411 reporter_2.pets.append(pet_3)
412 reporter_2.pets.append(pet_4)
413
414 session.commit()
415 session.close()
416
417 schema = get_schema()
418
419 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
420 # Starts new session to fully reset the engine / connection logging level
421 session = session_factory()
422 result = schema.execute("""
423 query {
424 reporters {
425 firstName
426 pets(first: 2) {
427 edges {
428 node {
429 name
430 }
431 }
432 }
433 }
434 }
435 """, context_value={"session": session})
436 messages = sqlalchemy_logging_handler.messages
437
438 assert len(messages) == 5
439
440 if is_sqlalchemy_version_less_than('1.3'):
441 # The batched SQL statement generated is different in 1.2.x
442 # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
443 # See https://git.io/JewQu
444 sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message]
445 assert len(sql_statements) == 1
446 return
447
448 assert messages == [
449 'BEGIN (implicit)',
450
451 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
452 'reporters.id AS reporters_id, '
453 'reporters.first_name AS reporters_first_name, '
454 'reporters.last_name AS reporters_last_name, '
455 'reporters.email AS reporters_email, '
456 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
457 'FROM reporters',
458 '()',
459
460 'SELECT reporters_1.id AS reporters_1_id, '
461 'pets.id AS pets_id, '
462 'pets.name AS pets_name, '
463 'pets.pet_kind AS pets_pet_kind, '
464 'pets.hair_kind AS pets_hair_kind, '
465 'pets.reporter_id AS pets_reporter_id \n'
466 'FROM reporters AS reporters_1 '
467 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id '
468 'JOIN pets ON pets.id = association_1.pet_id \n'
469 'WHERE reporters_1.id IN (?, ?) '
470 'ORDER BY pets.id',
471 '(1, 2)'
472 ]
473
474 assert not result.errors
475 result = to_std_dicts(result.data)
476 assert result == {
477 "reporters": [
478 {
479 "firstName": "Reporter_1",
480 "pets": {
481 "edges": [
482 {
483 "node": {
484 "name": "Pet_1",
485 },
486 },
487 {
488 "node": {
489 "name": "Pet_2",
490 },
491 },
492 ],
493 },
494 },
495 {
496 "firstName": "Reporter_2",
497 "pets": {
498 "edges": [
499 {
500 "node": {
501 "name": "Pet_3",
502 },
503 },
504 {
505 "node": {
506 "name": "Pet_4",
507 },
508 },
509 ],
510 },
511 },
512 ],
513 }
514
515
516 def test_disable_batching_via_ormfield(session_factory):
517 session = session_factory()
518 reporter_1 = Reporter(first_name='Reporter_1')
519 session.add(reporter_1)
520 reporter_2 = Reporter(first_name='Reporter_2')
521 session.add(reporter_2)
522 session.commit()
523 session.close()
524
525 class ReporterType(SQLAlchemyObjectType):
526 class Meta:
527 model = Reporter
528 interfaces = (relay.Node,)
529 batching = True
530
531 favorite_article = ORMField(batching=False)
532 articles = ORMField(batching=False)
533
534 class ArticleType(SQLAlchemyObjectType):
535 class Meta:
536 model = Article
537 interfaces = (relay.Node,)
538
539 class Query(graphene.ObjectType):
540 reporters = graphene.Field(graphene.List(ReporterType))
541
542 def resolve_reporters(self, info):
543 return info.context.get('session').query(Reporter).all()
544
545 schema = graphene.Schema(query=Query)
546
547 # Test one-to-one and many-to-one relationships
548 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
549 # Starts new session to fully reset the engine / connection logging level
550 session = session_factory()
551 schema.execute("""
552 query {
553 reporters {
554 favoriteArticle {
555 headline
556 }
557 }
558 }
559 """, context_value={"session": session})
560 messages = sqlalchemy_logging_handler.messages
561
562 select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
563 assert len(select_statements) == 2
564
565 # Test one-to-many and many-to-many relationships
566 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
567 # Starts new session to fully reset the engine / connection logging level
568 session = session_factory()
569 schema.execute("""
570 query {
571 reporters {
572 articles {
573 edges {
574 node {
575 headline
576 }
577 }
578 }
579 }
580 }
581 """, context_value={"session": session})
582 messages = sqlalchemy_logging_handler.messages
583
584 select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
585 assert len(select_statements) == 2
586
587
588 def test_connection_factory_field_overrides_batching_is_false(session_factory):
589 session = session_factory()
590 reporter_1 = Reporter(first_name='Reporter_1')
591 session.add(reporter_1)
592 reporter_2 = Reporter(first_name='Reporter_2')
593 session.add(reporter_2)
594 session.commit()
595 session.close()
596
597 class ReporterType(SQLAlchemyObjectType):
598 class Meta:
599 model = Reporter
600 interfaces = (relay.Node,)
601 batching = False
602 connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
603
604 articles = ORMField(batching=False)
605
606 class ArticleType(SQLAlchemyObjectType):
607 class Meta:
608 model = Article
609 interfaces = (relay.Node,)
610
611 class Query(graphene.ObjectType):
612 reporters = graphene.Field(graphene.List(ReporterType))
613
614 def resolve_reporters(self, info):
615 return info.context.get('session').query(Reporter).all()
616
617 schema = graphene.Schema(query=Query)
618
619 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
620 # Starts new session to fully reset the engine / connection logging level
621 session = session_factory()
622 schema.execute("""
623 query {
624 reporters {
625 articles {
626 edges {
627 node {
628 headline
629 }
630 }
631 }
632 }
633 }
634 """, context_value={"session": session})
635 messages = sqlalchemy_logging_handler.messages
636
637 if is_sqlalchemy_version_less_than('1.3'):
638 # The batched SQL statement generated is different in 1.2.x
639 # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin`
640 # See https://git.io/JewQu
641 select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message]
642 else:
643 select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
644 assert len(select_statements) == 1
645
646
647 def test_connection_factory_field_overrides_batching_is_true(session_factory):
648 session = session_factory()
649 reporter_1 = Reporter(first_name='Reporter_1')
650 session.add(reporter_1)
651 reporter_2 = Reporter(first_name='Reporter_2')
652 session.add(reporter_2)
653 session.commit()
654 session.close()
655
656 class ReporterType(SQLAlchemyObjectType):
657 class Meta:
658 model = Reporter
659 interfaces = (relay.Node,)
660 batching = True
661 connection_field_factory = default_connection_field_factory
662
663 articles = ORMField(batching=True)
664
665 class ArticleType(SQLAlchemyObjectType):
666 class Meta:
667 model = Article
668 interfaces = (relay.Node,)
669
670 class Query(graphene.ObjectType):
671 reporters = graphene.Field(graphene.List(ReporterType))
672
673 def resolve_reporters(self, info):
674 return info.context.get('session').query(Reporter).all()
675
676 schema = graphene.Schema(query=Query)
677
678 with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
679 # Starts new session to fully reset the engine / connection logging level
680 session = session_factory()
681 schema.execute("""
682 query {
683 reporters {
684 articles {
685 edges {
686 node {
687 headline
688 }
689 }
690 }
691 }
692 }
693 """, context_value={"session": session})
694 messages = sqlalchemy_logging_handler.messages
695
696 select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
697 assert len(select_statements) == 2
0 import pytest
1 from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend
2
3 import graphene
4 from graphene import relay
5
6 from ..fields import BatchSQLAlchemyConnectionField
7 from ..types import SQLAlchemyObjectType
8 from .models import Article, HairKind, Pet, Reporter
9 from .utils import is_sqlalchemy_version_less_than
10
11 if is_sqlalchemy_version_less_than('1.2'):
12 pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True)
13
14
15 def get_schema():
16 class ReporterType(SQLAlchemyObjectType):
17 class Meta:
18 model = Reporter
19 interfaces = (relay.Node,)
20 connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
21
22 class ArticleType(SQLAlchemyObjectType):
23 class Meta:
24 model = Article
25 interfaces = (relay.Node,)
26 connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
27
28 class PetType(SQLAlchemyObjectType):
29 class Meta:
30 model = Pet
31 interfaces = (relay.Node,)
32 connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
33
34 class Query(graphene.ObjectType):
35 articles = graphene.Field(graphene.List(ArticleType))
36 reporters = graphene.Field(graphene.List(ReporterType))
37
38 def resolve_articles(self, info):
39 return info.context.get('session').query(Article).all()
40
41 def resolve_reporters(self, info):
42 return info.context.get('session').query(Reporter).all()
43
44 return graphene.Schema(query=Query)
45
46
47 def benchmark_query(session_factory, benchmark, query):
48 schema = get_schema()
49 cached_backend = GraphQLCachedBackend(GraphQLCoreBackend())
50 cached_backend.document_from_string(schema, query) # Prime cache
51
52 @benchmark
53 def execute_query():
54 result = schema.execute(
55 query,
56 context_value={"session": session_factory()},
57 backend=cached_backend,
58 )
59 assert not result.errors
60
61
62 def test_one_to_one(session_factory, benchmark):
63 session = session_factory()
64
65 reporter_1 = Reporter(
66 first_name='Reporter_1',
67 )
68 session.add(reporter_1)
69 reporter_2 = Reporter(
70 first_name='Reporter_2',
71 )
72 session.add(reporter_2)
73
74 article_1 = Article(headline='Article_1')
75 article_1.reporter = reporter_1
76 session.add(article_1)
77
78 article_2 = Article(headline='Article_2')
79 article_2.reporter = reporter_2
80 session.add(article_2)
81
82 session.commit()
83 session.close()
84
85 benchmark_query(session_factory, benchmark, """
86 query {
87 reporters {
88 firstName
89 favoriteArticle {
90 headline
91 }
92 }
93 }
94 """)
95
96
97 def test_many_to_one(session_factory, benchmark):
98 session = session_factory()
99
100 reporter_1 = Reporter(
101 first_name='Reporter_1',
102 )
103 session.add(reporter_1)
104 reporter_2 = Reporter(
105 first_name='Reporter_2',
106 )
107 session.add(reporter_2)
108
109 article_1 = Article(headline='Article_1')
110 article_1.reporter = reporter_1
111 session.add(article_1)
112
113 article_2 = Article(headline='Article_2')
114 article_2.reporter = reporter_2
115 session.add(article_2)
116
117 session.commit()
118 session.close()
119
120 benchmark_query(session_factory, benchmark, """
121 query {
122 articles {
123 headline
124 reporter {
125 firstName
126 }
127 }
128 }
129 """)
130
131
132 def test_one_to_many(session_factory, benchmark):
133 session = session_factory()
134
135 reporter_1 = Reporter(
136 first_name='Reporter_1',
137 )
138 session.add(reporter_1)
139 reporter_2 = Reporter(
140 first_name='Reporter_2',
141 )
142 session.add(reporter_2)
143
144 article_1 = Article(headline='Article_1')
145 article_1.reporter = reporter_1
146 session.add(article_1)
147
148 article_2 = Article(headline='Article_2')
149 article_2.reporter = reporter_1
150 session.add(article_2)
151
152 article_3 = Article(headline='Article_3')
153 article_3.reporter = reporter_2
154 session.add(article_3)
155
156 article_4 = Article(headline='Article_4')
157 article_4.reporter = reporter_2
158 session.add(article_4)
159
160 session.commit()
161 session.close()
162
163 benchmark_query(session_factory, benchmark, """
164 query {
165 reporters {
166 firstName
167 articles(first: 2) {
168 edges {
169 node {
170 headline
171 }
172 }
173 }
174 }
175 }
176 """)
177
178
179 def test_many_to_many(session_factory, benchmark):
180 session = session_factory()
181
182 reporter_1 = Reporter(
183 first_name='Reporter_1',
184 )
185 session.add(reporter_1)
186 reporter_2 = Reporter(
187 first_name='Reporter_2',
188 )
189 session.add(reporter_2)
190
191 pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG)
192 session.add(pet_1)
193
194 pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG)
195 session.add(pet_2)
196
197 reporter_1.pets.append(pet_1)
198 reporter_1.pets.append(pet_2)
199
200 pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG)
201 session.add(pet_3)
202
203 pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG)
204 session.add(pet_4)
205
206 reporter_2.pets.append(pet_3)
207 reporter_2.pets.append(pet_4)
208
209 session.commit()
210 session.close()
211
212 benchmark_query(session_factory, benchmark, """
213 query {
214 reporters {
215 firstName
216 pets(first: 2) {
217 edges {
218 node {
219 name
220 }
221 }
222 }
223 }
224 }
225 """)
00 import enum
11
2 from py.test import raises
3 from sqlalchemy import Column, Table, case, func, select, types
2 import pytest
3 from sqlalchemy import Column, func, select, types
44 from sqlalchemy.dialects import postgresql
55 from sqlalchemy.ext.declarative import declarative_base
6 from sqlalchemy.inspection import inspect
67 from sqlalchemy.orm import column_property, composite
7 from sqlalchemy.sql.elements import Label
88 from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType
99
1010 import graphene
1717 convert_sqlalchemy_relationship)
1818 from ..fields import (UnsortedSQLAlchemyConnectionField,
1919 default_connection_field_factory)
20 from ..registry import Registry
20 from ..registry import Registry, get_global_registry
2121 from ..types import SQLAlchemyObjectType
22 from .models import Article, Pet, Reporter
23
24
25 def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs):
26 column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs)
27 graphene_type = convert_sqlalchemy_column(column)
28 assert isinstance(graphene_type, graphene_field)
29 field = (
30 graphene_type
31 if isinstance(graphene_type, graphene.Field)
32 else graphene_type.Field()
33 )
34 assert field.description == "Custom Help Text"
35 return field
36
37
38 def assert_composite_conversion(
39 composite_class, composite_columns, graphene_field, registry, **kwargs
40 ):
41 composite_column = composite(
42 composite_class, *composite_columns, doc="Custom Help Text", **kwargs
43 )
44 graphene_type = convert_sqlalchemy_composite(composite_column, registry)
45 assert isinstance(graphene_type, graphene_field)
46 field = graphene_type.Field()
47 # SQLAlchemy currently does not persist the doc onto the column, even though
48 # the documentation says it does....
49 # assert field.description == 'Custom Help Text'
50 return field
22 from .models import Article, CompositeFullName, Pet, Reporter
23
24
25 def mock_resolver():
26 pass
27
28
29 def get_field(sqlalchemy_type, **column_kwargs):
30 class Model(declarative_base()):
31 __tablename__ = 'model'
32 id_ = Column(types.Integer, primary_key=True)
33 column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs)
34
35 column_prop = inspect(Model).column_attrs['column']
36 return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)
37
38
39 def get_field_from_column(column_):
40 class Model(declarative_base()):
41 __tablename__ = 'model'
42 id_ = Column(types.Integer, primary_key=True)
43 column = column_
44
45 column_prop = inspect(Model).column_attrs['column']
46 return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)
5147
5248
5349 def test_should_unknown_sqlalchemy_field_raise_exception():
54 with raises(Exception) as excinfo:
55 convert_sqlalchemy_column(None)
56 assert "Don't know how to convert the SQLAlchemy field" in str(excinfo.value)
50 re_err = "Don't know how to convert the SQLAlchemy field"
51 with pytest.raises(Exception, match=re_err):
52 # support legacy Binary type and subsequent LargeBinary
53 get_field(getattr(types, 'LargeBinary', types.Binary)())
5754
5855
5956 def test_should_date_convert_string():
60 assert_column_conversion(types.Date(), graphene.String)
61
62
63 def test_should_datetime_convert_string():
64 assert_column_conversion(types.DateTime(), DateTime)
57 assert get_field(types.Date()).type == graphene.String
58
59
60 def test_should_datetime_convert_datetime():
61 assert get_field(types.DateTime()).type == DateTime
6562
6663
6764 def test_should_time_convert_string():
68 assert_column_conversion(types.Time(), graphene.String)
65 assert get_field(types.Time()).type == graphene.String
6966
7067
7168 def test_should_string_convert_string():
72 assert_column_conversion(types.String(), graphene.String)
69 assert get_field(types.String()).type == graphene.String
7370
7471
7572 def test_should_text_convert_string():
76 assert_column_conversion(types.Text(), graphene.String)
73 assert get_field(types.Text()).type == graphene.String
7774
7875
7976 def test_should_unicode_convert_string():
80 assert_column_conversion(types.Unicode(), graphene.String)
77 assert get_field(types.Unicode()).type == graphene.String
8178
8279
8380 def test_should_unicodetext_convert_string():
84 assert_column_conversion(types.UnicodeText(), graphene.String)
81 assert get_field(types.UnicodeText()).type == graphene.String
8582
8683
8784 def test_should_enum_convert_enum():
88 field = assert_column_conversion(
89 types.Enum(enum.Enum("one", "two")), graphene.Field
90 )
85 field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two"))))
9186 field_type = field.type()
9287 assert isinstance(field_type, graphene.Enum)
93 assert hasattr(field_type, "two")
94 field = assert_column_conversion(
95 types.Enum("one", "two", name="two_numbers"), graphene.Field
96 )
88 assert field_type._meta.name == "TwoNumbers"
89 assert hasattr(field_type, "ONE")
90 assert not hasattr(field_type, "one")
91 assert hasattr(field_type, "TWO")
92 assert not hasattr(field_type, "two")
93
94 field = get_field(types.Enum("one", "two", name="two_numbers"))
9795 field_type = field.type()
98 assert field_type.__class__.__name__ == "two_numbers"
9996 assert isinstance(field_type, graphene.Enum)
100 assert hasattr(field_type, "two")
97 assert field_type._meta.name == "TwoNumbers"
98 assert hasattr(field_type, "ONE")
99 assert not hasattr(field_type, "one")
100 assert hasattr(field_type, "TWO")
101 assert not hasattr(field_type, "two")
102
103
104 def test_should_not_enum_convert_enum_without_name():
105 field = get_field(types.Enum("one", "two"))
106 re_err = r"No type name specified for Enum\('one', 'two'\)"
107 with pytest.raises(TypeError, match=re_err):
108 field.type()
101109
102110
103111 def test_should_small_integer_convert_int():
104 assert_column_conversion(types.SmallInteger(), graphene.Int)
112 assert get_field(types.SmallInteger()).type == graphene.Int
105113
106114
107115 def test_should_big_integer_convert_int():
108 assert_column_conversion(types.BigInteger(), graphene.Float)
116 assert get_field(types.BigInteger()).type == graphene.Float
109117
110118
111119 def test_should_integer_convert_int():
112 assert_column_conversion(types.Integer(), graphene.Int)
113
114
115 def test_should_integer_convert_id():
116 assert_column_conversion(types.Integer(), graphene.ID, primary_key=True)
120 assert get_field(types.Integer()).type == graphene.Int
121
122
123 def test_should_primary_integer_convert_id():
124 assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID)
117125
118126
119127 def test_should_boolean_convert_boolean():
120 assert_column_conversion(types.Boolean(), graphene.Boolean)
128 assert get_field(types.Boolean()).type == graphene.Boolean
121129
122130
123131 def test_should_float_convert_float():
124 assert_column_conversion(types.Float(), graphene.Float)
132 assert get_field(types.Float()).type == graphene.Float
125133
126134
127135 def test_should_numeric_convert_float():
128 assert_column_conversion(types.Numeric(), graphene.Float)
129
130
131 def test_should_label_convert_string():
132 label = Label("label_test", case([], else_="foo"), type_=types.Unicode())
133 graphene_type = convert_sqlalchemy_column(label)
134 assert isinstance(graphene_type, graphene.String)
135
136
137 def test_should_label_convert_int():
138 label = Label("int_label_test", case([], else_="foo"), type_=types.Integer())
139 graphene_type = convert_sqlalchemy_column(label)
140 assert isinstance(graphene_type, graphene.Int)
136 assert get_field(types.Numeric()).type == graphene.Float
141137
142138
143139 def test_should_choice_convert_enum():
144 TYPES = [(u"es", u"Spanish"), (u"en", u"English")]
145 column = Column(ChoiceType(TYPES), doc="Language", name="language")
146 Base = declarative_base()
147
148 Table("translatedmodel", Base.metadata, column)
149 graphene_type = convert_sqlalchemy_column(column)
140 field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")]))
141 graphene_type = field.type
150142 assert issubclass(graphene_type, graphene.Enum)
151 assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE"
152 assert graphene_type._meta.description == "Language"
143 assert graphene_type._meta.name == "MODEL_COLUMN"
153144 assert graphene_type._meta.enum.__members__["es"].value == "Spanish"
154145 assert graphene_type._meta.enum.__members__["en"].value == "English"
155146
156147
148 def test_should_enum_choice_convert_enum():
149 class TestEnum(enum.Enum):
150 es = u"Spanish"
151 en = u"English"
152
153 field = get_field(ChoiceType(TestEnum, impl=types.String()))
154 graphene_type = field.type
155 assert issubclass(graphene_type, graphene.Enum)
156 assert graphene_type._meta.name == "MODEL_COLUMN"
157 assert graphene_type._meta.enum.__members__["es"].value == "Spanish"
158 assert graphene_type._meta.enum.__members__["en"].value == "English"
159
160
161 def test_should_intenum_choice_convert_enum():
162 class TestEnum(enum.IntEnum):
163 one = 1
164 two = 2
165
166 field = get_field(ChoiceType(TestEnum, impl=types.String()))
167 graphene_type = field.type
168 assert issubclass(graphene_type, graphene.Enum)
169 assert graphene_type._meta.name == "MODEL_COLUMN"
170 assert graphene_type._meta.enum.__members__["one"].value == 1
171 assert graphene_type._meta.enum.__members__["two"].value == 2
172
173
157174 def test_should_columproperty_convert():
158
159 Base = declarative_base()
160
161 class Test(Base):
162 __tablename__ = "test"
163 id = Column(types.Integer, primary_key=True)
164 column = column_property(
165 select([func.sum(func.cast(id, types.Integer))]).where(id == 1)
166 )
167
168 graphene_type = convert_sqlalchemy_column(Test.column)
169 assert not graphene_type.kwargs["required"]
175 field = get_field_from_column(column_property(
176 select([func.sum(func.cast(id, types.Integer))]).where(id == 1)
177 ))
178
179 assert field.type == graphene.Int
170180
171181
172182 def test_should_scalar_list_convert_list():
173 assert_column_conversion(ScalarListType(), graphene.List)
183 field = get_field(ScalarListType())
184 assert isinstance(field.type, graphene.List)
185 assert field.type.of_type == graphene.String
174186
175187
176188 def test_should_jsontype_convert_jsonstring():
177 assert_column_conversion(JSONType(), JSONString)
189 assert get_field(JSONType()).type == JSONString
178190
179191
180192 def test_should_manytomany_convert_connectionorlist():
181 registry = Registry()
182 dynamic_field = convert_sqlalchemy_relationship(
183 Reporter.pets.property, registry, default_connection_field_factory
193 class A(SQLAlchemyObjectType):
194 class Meta:
195 model = Article
196
197 dynamic_field = convert_sqlalchemy_relationship(
198 Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name',
184199 )
185200 assert isinstance(dynamic_field, graphene.Dynamic)
186201 assert not dynamic_field.get_type()
192207 model = Pet
193208
194209 dynamic_field = convert_sqlalchemy_relationship(
195 Reporter.pets.property, A._meta.registry, default_connection_field_factory
210 Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name',
196211 )
197212 assert isinstance(dynamic_field, graphene.Dynamic)
198213 graphene_type = dynamic_field.get_type()
208223 interfaces = (Node,)
209224
210225 dynamic_field = convert_sqlalchemy_relationship(
211 Reporter.pets.property, A._meta.registry, default_connection_field_factory
226 Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name',
212227 )
213228 assert isinstance(dynamic_field, graphene.Dynamic)
214229 assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField)
215230
216231
217232 def test_should_manytoone_convert_connectionorlist():
218 registry = Registry()
219 dynamic_field = convert_sqlalchemy_relationship(
220 Article.reporter.property, registry, default_connection_field_factory
233 class A(SQLAlchemyObjectType):
234 class Meta:
235 model = Article
236
237 dynamic_field = convert_sqlalchemy_relationship(
238 Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name',
221239 )
222240 assert isinstance(dynamic_field, graphene.Dynamic)
223241 assert not dynamic_field.get_type()
229247 model = Reporter
230248
231249 dynamic_field = convert_sqlalchemy_relationship(
232 Article.reporter.property, A._meta.registry, default_connection_field_factory
250 Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name',
233251 )
234252 assert isinstance(dynamic_field, graphene.Dynamic)
235253 graphene_type = dynamic_field.get_type()
244262 interfaces = (Node,)
245263
246264 dynamic_field = convert_sqlalchemy_relationship(
247 Article.reporter.property, A._meta.registry, default_connection_field_factory
265 Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name',
248266 )
249267 assert isinstance(dynamic_field, graphene.Dynamic)
250268 graphene_type = dynamic_field.get_type()
259277 interfaces = (Node,)
260278
261279 dynamic_field = convert_sqlalchemy_relationship(
262 Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory
280 Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name',
263281 )
264282 assert isinstance(dynamic_field, graphene.Dynamic)
265283 graphene_type = dynamic_field.get_type()
268286
269287
270288 def test_should_postgresql_uuid_convert():
271 assert_column_conversion(postgresql.UUID(), graphene.String)
289 assert get_field(postgresql.UUID()).type == graphene.String
272290
273291
274292 def test_should_postgresql_enum_convert():
275 field = assert_column_conversion(
276 postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field
277 )
293 field = get_field(postgresql.ENUM("one", "two", name="two_numbers"))
278294 field_type = field.type()
279 assert field_type.__class__.__name__ == "two_numbers"
280295 assert isinstance(field_type, graphene.Enum)
281 assert hasattr(field_type, "two")
296 assert field_type._meta.name == "TwoNumbers"
297 assert hasattr(field_type, "ONE")
298 assert not hasattr(field_type, "one")
299 assert hasattr(field_type, "TWO")
300 assert not hasattr(field_type, "two")
282301
283302
284303 def test_should_postgresql_py_enum_convert():
285 field = assert_column_conversion(
286 postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field
287 )
304 field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"))
288305 field_type = field.type()
289 assert field_type.__class__.__name__ == "TwoNumbers"
306 assert field_type._meta.name == "TwoNumbers"
290307 assert isinstance(field_type, graphene.Enum)
291 assert hasattr(field_type, "two")
308 assert hasattr(field_type, "ONE")
309 assert not hasattr(field_type, "one")
310 assert hasattr(field_type, "TWO")
311 assert not hasattr(field_type, "two")
292312
293313
294314 def test_should_postgresql_array_convert():
295 assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List)
315 field = get_field(postgresql.ARRAY(types.Integer))
316 assert isinstance(field.type, graphene.List)
317 assert field.type.of_type == graphene.Int
318
319
320 def test_should_array_convert():
321 field = get_field(types.ARRAY(types.Integer))
322 assert isinstance(field.type, graphene.List)
323 assert field.type.of_type == graphene.Int
296324
297325
298326 def test_should_postgresql_json_convert():
299 assert_column_conversion(postgresql.JSON(), JSONString)
327 assert get_field(postgresql.JSON()).type == graphene.JSONString
300328
301329
302330 def test_should_postgresql_jsonb_convert():
303 assert_column_conversion(postgresql.JSONB(), JSONString)
331 assert get_field(postgresql.JSONB()).type == graphene.JSONString
304332
305333
306334 def test_should_postgresql_hstore_convert():
307 assert_column_conversion(postgresql.HSTORE(), JSONString)
335 assert get_field(postgresql.HSTORE()).type == graphene.JSONString
308336
309337
310338 def test_should_composite_convert():
311 class CompositeClass(object):
339 registry = Registry()
340
341 class CompositeClass:
312342 def __init__(self, col1, col2):
313343 self.col1 = col1
314344 self.col2 = col2
315345
316 registry = Registry()
317
318346 @convert_sqlalchemy_composite.register(CompositeClass, registry)
319347 def convert_composite_class(composite, registry):
320348 return graphene.String(description=composite.doc)
321349
322 assert_composite_conversion(
323 CompositeClass,
324 (Column(types.Unicode(50)), Column(types.Unicode(50))),
325 graphene.String,
350 field = convert_sqlalchemy_composite(
351 composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"),
326352 registry,
327 )
353 mock_resolver,
354 )
355 assert isinstance(field, graphene.String)
328356
329357
330358 def test_should_unknown_sqlalchemy_composite_raise_exception():
331 registry = Registry()
332
333 with raises(Exception) as excinfo:
334
335 class CompositeClass(object):
336 def __init__(self, col1, col2):
337 self.col1 = col1
338 self.col2 = col2
339
340 assert_composite_conversion(
341 CompositeClass,
342 (Column(types.Unicode(50)), Column(types.Unicode(50))),
343 graphene.String,
344 registry,
359 class CompositeClass:
360 def __init__(self, col1, col2):
361 self.col1 = col1
362 self.col2 = col2
363
364 re_err = "Don't know how to convert the composite field"
365 with pytest.raises(Exception, match=re_err):
366 convert_sqlalchemy_composite(
367 composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))),
368 Registry(),
369 mock_resolver,
345370 )
346
347 assert "Don't know how to convert the composite field" in str(excinfo.value)
0 from enum import Enum as PyEnum
1
2 import pytest
3 from sqlalchemy.types import Enum as SQLAlchemyEnumType
4
5 from graphene import Enum
6
7 from ..enums import _convert_sa_to_graphene_enum, enum_for_field
8 from ..types import SQLAlchemyObjectType
9 from .models import HairKind, Pet
10
11
12 def test_convert_sa_to_graphene_enum_bad_type():
13 re_err = "Expected sqlalchemy.types.Enum, but got: 'foo'"
14 with pytest.raises(TypeError, match=re_err):
15 _convert_sa_to_graphene_enum("foo")
16
17
18 def test_convert_sa_to_graphene_enum_based_on_py_enum():
19 class Color(PyEnum):
20 RED = 1
21 GREEN = 2
22 BLUE = 3
23
24 sa_enum = SQLAlchemyEnumType(Color)
25 graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName")
26 assert isinstance(graphene_enum, type(Enum))
27 assert graphene_enum._meta.name == "Color"
28 assert graphene_enum._meta.enum is Color
29
30
31 def test_convert_sa_to_graphene_enum_based_on_py_enum_with_bad_names():
32 class Color(PyEnum):
33 red = 1
34 green = 2
35 blue = 3
36
37 sa_enum = SQLAlchemyEnumType(Color)
38 graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName")
39 assert isinstance(graphene_enum, type(Enum))
40 assert graphene_enum._meta.name == "Color"
41 assert graphene_enum._meta.enum is not Color
42 assert [
43 (key, value.value)
44 for key, value in graphene_enum._meta.enum.__members__.items()
45 ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)]
46
47
48 def test_convert_sa_enum_to_graphene_enum_based_on_list_named():
49 sa_enum = SQLAlchemyEnumType("red", "green", "blue", name="color_values")
50 graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName")
51 assert isinstance(graphene_enum, type(Enum))
52 assert graphene_enum._meta.name == "ColorValues"
53 assert [
54 (key, value.value)
55 for key, value in graphene_enum._meta.enum.__members__.items()
56 ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')]
57
58
59 def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed():
60 sa_enum = SQLAlchemyEnumType("red", "green", "blue")
61 graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName")
62 assert isinstance(graphene_enum, type(Enum))
63 assert graphene_enum._meta.name == "FallbackName"
64 assert [
65 (key, value.value)
66 for key, value in graphene_enum._meta.enum.__members__.items()
67 ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')]
68
69
70 def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name():
71 sa_enum = SQLAlchemyEnumType("red", "green", "blue")
72 re_err = r"No type name specified for Enum\('red', 'green', 'blue'\)"
73 with pytest.raises(TypeError, match=re_err):
74 _convert_sa_to_graphene_enum(sa_enum)
75
76
77 def test_enum_for_field():
78 class PetType(SQLAlchemyObjectType):
79 class Meta:
80 model = Pet
81
82 enum = enum_for_field(PetType, 'pet_kind')
83 assert isinstance(enum, type(Enum))
84 assert enum._meta.name == "PetKind"
85 assert [
86 (key, value.value)
87 for key, value in enum._meta.enum.__members__.items()
88 ] == [("CAT", 'cat'), ("DOG", 'dog')]
89 enum2 = enum_for_field(PetType, 'pet_kind')
90 assert enum2 is enum
91 enum2 = PetType.enum_for_field('pet_kind')
92 assert enum2 is enum
93
94 enum = enum_for_field(PetType, 'hair_kind')
95 assert isinstance(enum, type(Enum))
96 assert enum._meta.name == "HairKind"
97 assert enum._meta.enum is HairKind
98 enum2 = PetType.enum_for_field('hair_kind')
99 assert enum2 is enum
100
101 re_err = r"Cannot get PetType\.other_kind"
102 with pytest.raises(TypeError, match=re_err):
103 enum_for_field(PetType, 'other_kind')
104 with pytest.raises(TypeError, match=re_err):
105 PetType.enum_for_field('other_kind')
106
107 re_err = r"PetType\.name does not map to enum column"
108 with pytest.raises(TypeError, match=re_err):
109 enum_for_field(PetType, 'name')
110 with pytest.raises(TypeError, match=re_err):
111 PetType.enum_for_field('name')
112
113 re_err = r"Expected a field name, but got: None"
114 with pytest.raises(TypeError, match=re_err):
115 enum_for_field(PetType, None)
116 with pytest.raises(TypeError, match=re_err):
117 PetType.enum_for_field(None)
118
119 re_err = "Expected SQLAlchemyObjectType, but got: None"
120 with pytest.raises(TypeError, match=re_err):
121 enum_for_field(None, 'other_kind')
00 import pytest
1 from promise import Promise
12
2 from graphene.relay import Connection
3 from graphene import NonNull, ObjectType
4 from graphene.relay import Connection, Node
35
4 from ..fields import SQLAlchemyConnectionField
6 from ..fields import (SQLAlchemyConnectionField,
7 UnsortedSQLAlchemyConnectionField)
58 from ..types import SQLAlchemyObjectType
6 from ..utils import sort_argument_for_model
7 from .models import Editor
9 from .models import Editor as EditorModel
810 from .models import Pet as PetModel
911
1012
1113 class Pet(SQLAlchemyObjectType):
1214 class Meta:
1315 model = PetModel
16 interfaces = (Node,)
1417
1518
16 class PetConn(Connection):
19 class Editor(SQLAlchemyObjectType):
1720 class Meta:
18 node = Pet
21 model = EditorModel
22
23 ##
24 # SQLAlchemyConnectionField
25 ##
26
27
28 def test_nonnull_sqlalachemy_connection():
29 field = SQLAlchemyConnectionField(NonNull(Pet.connection))
30 assert isinstance(field.type, NonNull)
31 assert issubclass(field.type.of_type, Connection)
32 assert field.type.of_type._meta.node is Pet
33
34
35 def test_required_sqlalachemy_connection():
36 field = SQLAlchemyConnectionField(Pet.connection, required=True)
37 assert isinstance(field.type, NonNull)
38 assert issubclass(field.type.of_type, Connection)
39 assert field.type.of_type._meta.node is Pet
40
41
42 def test_promise_connection_resolver():
43 def resolver(_obj, _info):
44 return Promise.resolve([])
45
46 result = UnsortedSQLAlchemyConnectionField.connection_resolver(
47 resolver, Pet.connection, Pet, None, None
48 )
49 assert isinstance(result, Promise)
50
51
52 def test_type_assert_sqlalchemy_object_type():
53 with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"):
54 SQLAlchemyConnectionField(ObjectType).type
55
56
57 def test_type_assert_object_has_connection():
58 with pytest.raises(AssertionError, match="doesn't have a connection"):
59 SQLAlchemyConnectionField(Editor).type
60
61 ##
62 # UnsortedSQLAlchemyConnectionField
63 ##
1964
2065
2166 def test_sort_added_by_default():
22 arg = SQLAlchemyConnectionField(PetConn)
23 assert "sort" in arg.args
24 assert arg.args["sort"] == sort_argument_for_model(PetModel)
67 field = SQLAlchemyConnectionField(Pet.connection)
68 assert "sort" in field.args
69 assert field.args["sort"] == Pet.sort_argument()
2570
2671
2772 def test_sort_can_be_removed():
28 arg = SQLAlchemyConnectionField(PetConn, sort=None)
29 assert "sort" not in arg.args
73 field = SQLAlchemyConnectionField(Pet.connection, sort=None)
74 assert "sort" not in field.args
3075
3176
3277 def test_custom_sort():
33 arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor))
34 assert arg.args["sort"] == sort_argument_for_model(Editor)
78 field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument())
79 assert field.args["sort"] == Editor.sort_argument()
3580
3681
37 def test_init_raises():
38 with pytest.raises(Exception, match="Cannot create sort"):
82 def test_sort_init_raises():
83 with pytest.raises(TypeError, match="Cannot create sort"):
3984 SQLAlchemyConnectionField(Connection)
0 import pytest
1 from sqlalchemy import create_engine
2 from sqlalchemy.orm import scoped_session, sessionmaker
3
40 import graphene
5 from graphene.relay import Connection, Node
6
1 from graphene.relay import Node
2
3 from ..converter import convert_sqlalchemy_composite
74 from ..fields import SQLAlchemyConnectionField
8 from ..registry import reset_global_registry
9 from ..types import SQLAlchemyObjectType
10 from ..utils import sort_argument_for_model, sort_enum_for_model
11 from .models import Article, Base, Editor, Hairkind, Pet, Reporter
12
13 db = create_engine("sqlite:///test_sqlalchemy.sqlite3")
14
15
16 @pytest.yield_fixture(scope="function")
17 def session():
18 reset_global_registry()
19 connection = db.engine.connect()
20 transaction = connection.begin()
21 Base.metadata.create_all(connection)
22
23 # options = dict(bind=connection, binds={})
24 session_factory = sessionmaker(bind=connection)
25 session = scoped_session(session_factory)
26
27 yield session
28
29 # Finalize test here
30 transaction.rollback()
31 connection.close()
32 session.remove()
33
34
35 def setup_fixtures(session):
36 pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG)
5 from ..types import ORMField, SQLAlchemyObjectType
6 from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter
7 from .utils import to_std_dicts
8
9
10 def add_test_data(session):
11 reporter = Reporter(
12 first_name='John', last_name='Doe', favorite_pet_kind='cat')
13 session.add(reporter)
14 pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT)
3715 session.add(pet)
38 reporter = Reporter(first_name="ABA", last_name="X")
39 session.add(reporter)
40 reporter2 = Reporter(first_name="ABO", last_name="Y")
41 session.add(reporter2)
42 article = Article(headline="Hi!")
16 pet.reporters.append(reporter)
17 article = Article(headline='Hi!')
4318 article.reporter = reporter
4419 session.add(article)
45 editor = Editor(name="John")
20 reporter = Reporter(
21 first_name='Jane', last_name='Roe', favorite_pet_kind='dog')
22 session.add(reporter)
23 pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG)
24 pet.reporters.append(reporter)
25 session.add(pet)
26 editor = Editor(name="Jack")
4627 session.add(editor)
4728 session.commit()
4829
4930
50 def test_should_query_well(session):
51 setup_fixtures(session)
31 def test_query_fields(session):
32 add_test_data(session)
33
34 @convert_sqlalchemy_composite.register(CompositeFullName)
35 def convert_composite_class(composite, registry):
36 return graphene.String()
5237
5338 class ReporterType(SQLAlchemyObjectType):
5439 class Meta:
5843 reporter = graphene.Field(ReporterType)
5944 reporters = graphene.List(ReporterType)
6045
61 def resolve_reporter(self, *args, **kwargs):
46 def resolve_reporter(self, _info):
6247 return session.query(Reporter).first()
6348
64 def resolve_reporters(self, *args, **kwargs):
49 def resolve_reporters(self, _info):
6550 return session.query(Reporter)
6651
6752 query = """
68 query ReporterQuery {
53 query {
6954 reporter {
70 firstName,
71 lastName,
72 email
55 firstName
56 columnProp
57 hybridProp
58 compositeProp
7359 }
7460 reporters {
7561 firstName
7763 }
7864 """
7965 expected = {
80 "reporter": {"firstName": "ABA", "lastName": "X", "email": None},
81 "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}],
66 "reporter": {
67 "firstName": "John",
68 "hybridProp": "John",
69 "columnProp": 2,
70 "compositeProp": "John Doe",
71 },
72 "reporters": [{"firstName": "John"}, {"firstName": "Jane"}],
8273 }
8374 schema = graphene.Schema(query=Query)
8475 result = schema.execute(query)
8576 assert not result.errors
86 assert result.data == expected
87
88
89 def test_should_query_enums(session):
90 setup_fixtures(session)
91
92 class PetType(SQLAlchemyObjectType):
93 class Meta:
94 model = Pet
95
96 class Query(graphene.ObjectType):
97 pet = graphene.Field(PetType)
98
99 def resolve_pet(self, *args, **kwargs):
100 return session.query(Pet).first()
101
102 query = """
103 query PetQuery {
104 pet {
105 name,
106 petKind
107 hairKind
108 }
109 }
110 """
111 expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
112 schema = graphene.Schema(query=Query)
113 result = schema.execute(query)
114 assert not result.errors
115 assert result.data == expected, result.data
116
117
118 def test_enum_parameter(session):
119 setup_fixtures(session)
120
121 class PetType(SQLAlchemyObjectType):
122 class Meta:
123 model = Pet
124
125 class Query(graphene.ObjectType):
126 pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type))
127
128 def resolve_pet(self, info, kind=None, *args, **kwargs):
129 query = session.query(Pet)
130 if kind:
131 query = query.filter(Pet.pet_kind == kind)
132 return query.first()
133
134 query = """
135 query PetQuery($kind: pet_kind) {
136 pet(kind: $kind) {
137 name,
138 petKind
139 hairKind
140 }
141 }
142 """
143 expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
144 schema = graphene.Schema(query=Query)
145 result = schema.execute(query, variables={"kind": "cat"})
146 assert not result.errors
147 assert result.data == {"pet": None}
148 result = schema.execute(query, variables={"kind": "dog"})
149 assert not result.errors
150 assert result.data == expected, result.data
151
152
153 def test_py_enum_parameter(session):
154 setup_fixtures(session)
155
156 class PetType(SQLAlchemyObjectType):
157 class Meta:
158 model = Pet
159
160 class Query(graphene.ObjectType):
161 pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type))
162
163 def resolve_pet(self, info, kind=None, *args, **kwargs):
164 query = session.query(Pet)
165 if kind:
166 # XXX Why kind passed in as a str instead of a Hairkind instance?
167 query = query.filter(Pet.hair_kind == Hairkind(kind))
168 return query.first()
169
170 query = """
171 query PetQuery($kind: Hairkind) {
172 pet(kind: $kind) {
173 name,
174 petKind
175 hairKind
176 }
177 }
178 """
179 expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}}
180 schema = graphene.Schema(query=Query)
181 result = schema.execute(query, variables={"kind": "SHORT"})
182 assert not result.errors
183 assert result.data == {"pet": None}
184 result = schema.execute(query, variables={"kind": "LONG"})
185 assert not result.errors
186 assert result.data == expected, result.data
187
188
189 def test_should_node(session):
190 setup_fixtures(session)
77 result = to_std_dicts(result.data)
78 assert result == expected
79
80
81 def test_query_node(session):
82 add_test_data(session)
19183
19284 class ReporterNode(SQLAlchemyObjectType):
19385 class Meta:
20395 model = Article
20496 interfaces = (Node,)
20597
206 # @classmethod
207 # def get_node(cls, id, info):
208 # return Article(id=1, headline='Article node')
209
210 class ArticleConnection(Connection):
211 class Meta:
212 node = ArticleNode
213
21498 class Query(graphene.ObjectType):
21599 node = Node.Field()
216100 reporter = graphene.Field(ReporterNode)
217 article = graphene.Field(ArticleNode)
218 all_articles = SQLAlchemyConnectionField(ArticleConnection)
219
220 def resolve_reporter(self, *args, **kwargs):
101 all_articles = SQLAlchemyConnectionField(ArticleNode.connection)
102
103 def resolve_reporter(self, _info):
221104 return session.query(Reporter).first()
222105
223 def resolve_article(self, *args, **kwargs):
224 return session.query(Article).first()
225
226 query = """
227 query ReporterQuery {
106 query = """
107 query {
228108 reporter {
229 id,
230 firstName,
109 id
110 firstName
231111 articles {
232112 edges {
233113 node {
235115 }
236116 }
237117 }
238 lastName,
239 email
240118 }
241119 allArticles {
242120 edges {
259137 expected = {
260138 "reporter": {
261139 "id": "UmVwb3J0ZXJOb2RlOjE=",
262 "firstName": "ABA",
263 "lastName": "X",
264 "email": None,
140 "firstName": "John",
265141 "articles": {"edges": [{"node": {"headline": "Hi!"}}]},
266142 },
267143 "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]},
270146 schema = graphene.Schema(query=Query)
271147 result = schema.execute(query, context_value={"session": session})
272148 assert not result.errors
273 assert result.data == expected
274
275
276 def test_should_custom_identifier(session):
277 setup_fixtures(session)
149 result = to_std_dicts(result.data)
150 assert result == expected
151
152
153 def test_orm_field(session):
154 add_test_data(session)
155
156 @convert_sqlalchemy_composite.register(CompositeFullName)
157 def convert_composite_class(composite, registry):
158 return graphene.String()
159
160 class ReporterType(SQLAlchemyObjectType):
161 class Meta:
162 model = Reporter
163 interfaces = (Node,)
164
165 first_name_v2 = ORMField(model_attr='first_name')
166 hybrid_prop_v2 = ORMField(model_attr='hybrid_prop')
167 column_prop_v2 = ORMField(model_attr='column_prop')
168 composite_prop = ORMField()
169 favorite_article_v2 = ORMField(model_attr='favorite_article')
170 articles_v2 = ORMField(model_attr='articles')
171
172 class ArticleType(SQLAlchemyObjectType):
173 class Meta:
174 model = Article
175 interfaces = (Node,)
176
177 class Query(graphene.ObjectType):
178 reporter = graphene.Field(ReporterType)
179
180 def resolve_reporter(self, _info):
181 return session.query(Reporter).first()
182
183 query = """
184 query {
185 reporter {
186 firstNameV2
187 hybridPropV2
188 columnPropV2
189 compositeProp
190 favoriteArticleV2 {
191 headline
192 }
193 articlesV2(first: 1) {
194 edges {
195 node {
196 headline
197 }
198 }
199 }
200 }
201 }
202 """
203 expected = {
204 "reporter": {
205 "firstNameV2": "John",
206 "hybridPropV2": "John",
207 "columnPropV2": 2,
208 "compositeProp": "John Doe",
209 "favoriteArticleV2": {"headline": "Hi!"},
210 "articlesV2": {"edges": [{"node": {"headline": "Hi!"}}]},
211 },
212 }
213 schema = graphene.Schema(query=Query)
214 result = schema.execute(query, context_value={"session": session})
215 assert not result.errors
216 result = to_std_dicts(result.data)
217 assert result == expected
218
219
220 def test_custom_identifier(session):
221 add_test_data(session)
278222
279223 class EditorNode(SQLAlchemyObjectType):
280224 class Meta:
281225 model = Editor
282226 interfaces = (Node,)
283227
284 class EditorConnection(Connection):
285 class Meta:
286 node = EditorNode
287
288228 class Query(graphene.ObjectType):
289229 node = Node.Field()
290 all_editors = SQLAlchemyConnectionField(EditorConnection)
291
292 query = """
293 query EditorQuery {
230 all_editors = SQLAlchemyConnectionField(EditorNode.connection)
231
232 query = """
233 query {
294234 allEditors {
295235 edges {
296236 node {
297 id,
237 id
298238 name
299239 }
300240 }
307247 }
308248 """
309249 expected = {
310 "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]},
311 "node": {"name": "John"},
250 "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]},
251 "node": {"name": "Jack"},
312252 }
313253
314254 schema = graphene.Schema(query=Query)
315255 result = schema.execute(query, context_value={"session": session})
316256 assert not result.errors
317 assert result.data == expected
318
319
320 def test_should_mutate_well(session):
321 setup_fixtures(session)
257 result = to_std_dicts(result.data)
258 assert result == expected
259
260
261 def test_mutation(session):
262 add_test_data(session)
322263
323264 class EditorNode(SQLAlchemyObjectType):
324265 class Meta:
363304 create_article = CreateArticle.Field()
364305
365306 query = """
366 mutation ArticleCreator {
307 mutation {
367308 createArticle(
368309 headline: "My Article"
369310 reporterId: "1"
384325 "ok": True,
385326 "article": {
386327 "headline": "My Article",
387 "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"},
328 "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"},
388329 },
389330 }
390331 }
392333 schema = graphene.Schema(query=Query, mutation=Mutation)
393334 result = schema.execute(query, context_value={"session": session})
394335 assert not result.errors
395 assert result.data == expected
396
397
398 def sort_setup(session):
399 pets = [
400 Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG),
401 Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG),
402 Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG),
403 ]
404 session.add_all(pets)
405 session.commit()
406
407
408 def test_sort(session):
409 sort_setup(session)
410
411 class PetNode(SQLAlchemyObjectType):
412 class Meta:
413 model = Pet
414 interfaces = (Node,)
415
416 class PetConnection(Connection):
417 class Meta:
418 node = PetNode
419
420 class Query(graphene.ObjectType):
421 defaultSort = SQLAlchemyConnectionField(PetConnection)
422 nameSort = SQLAlchemyConnectionField(PetConnection)
423 multipleSort = SQLAlchemyConnectionField(PetConnection)
424 descSort = SQLAlchemyConnectionField(PetConnection)
425 singleColumnSort = SQLAlchemyConnectionField(
426 PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet))
427 )
428 noDefaultSort = SQLAlchemyConnectionField(
429 PetConnection, sort=sort_argument_for_model(Pet, False)
430 )
431 noSort = SQLAlchemyConnectionField(PetConnection, sort=None)
432
433 query = """
434 query sortTest {
435 defaultSort{
436 edges{
437 node{
438 id
439 }
440 }
441 }
442 nameSort(sort: name_asc){
443 edges{
444 node{
445 name
446 }
447 }
448 }
449 multipleSort(sort: [pet_kind_asc, name_desc]){
450 edges{
451 node{
452 name
453 petKind
454 }
455 }
456 }
457 descSort(sort: [name_desc]){
458 edges{
459 node{
460 name
461 }
462 }
463 }
464 singleColumnSort(sort: name_desc){
465 edges{
466 node{
467 name
468 }
469 }
470 }
471 noDefaultSort(sort: name_asc){
472 edges{
473 node{
474 name
475 }
476 }
477 }
478 }
479 """
480
481 def makeNodes(nodeList):
482 nodes = [{"node": item} for item in nodeList]
483 return {"edges": nodes}
484
485 expected = {
486 "defaultSort": makeNodes(
487 [{"id": "UGV0Tm9kZToy"}, {"id": "UGV0Tm9kZToz"}, {"id": "UGV0Tm9kZToyMg=="}]
488 ),
489 "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]),
490 "noDefaultSort": makeNodes(
491 [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]
492 ),
493 "multipleSort": makeNodes(
494 [
495 {"name": "Alf", "petKind": "cat"},
496 {"name": "Lassie", "petKind": "dog"},
497 {"name": "Barf", "petKind": "dog"},
498 ]
499 ),
500 "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]),
501 "singleColumnSort": makeNodes(
502 [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]
503 ),
504 } # yapf: disable
505
506 schema = graphene.Schema(query=Query)
507 result = schema.execute(query, context_value={"session": session})
508 assert not result.errors
509 assert result.data == expected
510
511 queryError = """
512 query sortTest {
513 singleColumnSort(sort: [pet_kind_asc, name_desc]){
514 edges{
515 node{
516 name
517 }
518 }
519 }
520 }
521 """
522 result = schema.execute(queryError, context_value={"session": session})
523 assert result.errors is not None
524
525 queryNoSort = """
526 query sortTest {
527 noDefaultSort{
528 edges{
529 node{
530 name
531 }
532 }
533 }
534 noSort{
535 edges{
536 node{
537 name
538 }
539 }
540 }
541 }
542 """
543
544 expectedNoSort = {
545 "noDefaultSort": makeNodes(
546 [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]
547 ),
548 "noSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]),
549 } # yapf: disable
550
551 result = schema.execute(queryNoSort, context_value={"session": session})
552 assert not result.errors
553 for key, value in result.data.items():
554 assert set(node["node"]["name"] for node in value["edges"]) == set(
555 node["node"]["name"] for node in expectedNoSort[key]["edges"]
556 )
336 result = to_std_dicts(result.data)
337 assert result == expected
0 import graphene
1
2 from ..types import SQLAlchemyObjectType
3 from .models import HairKind, Pet, Reporter
4 from .test_query import add_test_data, to_std_dicts
5
6
7 def test_query_pet_kinds(session):
8 add_test_data(session)
9
10 class PetType(SQLAlchemyObjectType):
11
12 class Meta:
13 model = Pet
14
15 class ReporterType(SQLAlchemyObjectType):
16 class Meta:
17 model = Reporter
18
19 class Query(graphene.ObjectType):
20 reporter = graphene.Field(ReporterType)
21 reporters = graphene.List(ReporterType)
22 pets = graphene.List(PetType, kind=graphene.Argument(
23 PetType.enum_for_field('pet_kind')))
24
25 def resolve_reporter(self, _info):
26 return session.query(Reporter).first()
27
28 def resolve_reporters(self, _info):
29 return session.query(Reporter)
30
31 def resolve_pets(self, _info, kind):
32 query = session.query(Pet)
33 if kind:
34 query = query.filter_by(pet_kind=kind)
35 return query
36
37 query = """
38 query ReporterQuery {
39 reporter {
40 firstName
41 lastName
42 email
43 favoritePetKind
44 pets {
45 name
46 petKind
47 }
48 }
49 reporters {
50 firstName
51 favoritePetKind
52 }
53 pets(kind: DOG) {
54 name
55 petKind
56 }
57 }
58 """
59 expected = {
60 'reporter': {
61 'firstName': 'John',
62 'lastName': 'Doe',
63 'email': None,
64 'favoritePetKind': 'CAT',
65 'pets': [{
66 'name': 'Garfield',
67 'petKind': 'CAT'
68 }]
69 },
70 'reporters': [{
71 'firstName': 'John',
72 'favoritePetKind': 'CAT',
73 }, {
74 'firstName': 'Jane',
75 'favoritePetKind': 'DOG',
76 }],
77 'pets': [{
78 'name': 'Lassie',
79 'petKind': 'DOG'
80 }]
81 }
82 schema = graphene.Schema(query=Query)
83 result = schema.execute(query)
84 assert not result.errors
85 assert result.data == expected
86
87
88 def test_query_more_enums(session):
89 add_test_data(session)
90
91 class PetType(SQLAlchemyObjectType):
92 class Meta:
93 model = Pet
94
95 class Query(graphene.ObjectType):
96 pet = graphene.Field(PetType)
97
98 def resolve_pet(self, _info):
99 return session.query(Pet).first()
100
101 query = """
102 query PetQuery {
103 pet {
104 name,
105 petKind
106 hairKind
107 }
108 }
109 """
110 expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
111 schema = graphene.Schema(query=Query)
112 result = schema.execute(query)
113 assert not result.errors
114 result = to_std_dicts(result.data)
115 assert result == expected
116
117
118 def test_enum_as_argument(session):
119 add_test_data(session)
120
121 class PetType(SQLAlchemyObjectType):
122 class Meta:
123 model = Pet
124
125 class Query(graphene.ObjectType):
126 pet = graphene.Field(
127 PetType,
128 kind=graphene.Argument(PetType.enum_for_field('pet_kind')))
129
130 def resolve_pet(self, info, kind=None):
131 query = session.query(Pet)
132 if kind:
133 query = query.filter(Pet.pet_kind == kind)
134 return query.first()
135
136 query = """
137 query PetQuery($kind: PetKind) {
138 pet(kind: $kind) {
139 name,
140 petKind
141 hairKind
142 }
143 }
144 """
145
146 schema = graphene.Schema(query=Query)
147 result = schema.execute(query, variables={"kind": "CAT"})
148 assert not result.errors
149 expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
150 assert result.data == expected
151 result = schema.execute(query, variables={"kind": "DOG"})
152 assert not result.errors
153 expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}}
154 result = to_std_dicts(result.data)
155 assert result == expected
156
157
158 def test_py_enum_as_argument(session):
159 add_test_data(session)
160
161 class PetType(SQLAlchemyObjectType):
162 class Meta:
163 model = Pet
164
165 class Query(graphene.ObjectType):
166 pet = graphene.Field(
167 PetType,
168 kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type),
169 )
170
171 def resolve_pet(self, _info, kind=None):
172 query = session.query(Pet)
173 if kind:
174 # enum arguments are expected to be strings, not PyEnums
175 query = query.filter(Pet.hair_kind == HairKind(kind))
176 return query.first()
177
178 query = """
179 query PetQuery($kind: HairKind) {
180 pet(kind: $kind) {
181 name,
182 petKind
183 hairKind
184 }
185 }
186 """
187
188 schema = graphene.Schema(query=Query)
189 result = schema.execute(query, variables={"kind": "SHORT"})
190 assert not result.errors
191 expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}}
192 assert result.data == expected
193 result = schema.execute(query, variables={"kind": "LONG"})
194 assert not result.errors
195 expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}}
196 result = to_std_dicts(result.data)
197 assert result == expected
00 import pytest
1 from sqlalchemy.types import Enum as SQLAlchemyEnum
2
3 from graphene import Enum as GrapheneEnum
14
25 from ..registry import Registry
36 from ..types import SQLAlchemyObjectType
7 from ..utils import EnumValue
48 from .models import Pet
59
610
7 def test_register_incorrect_objecttype():
8 reg = Registry()
9
10 class Spam:
11 pass
12
13 with pytest.raises(AssertionError) as excinfo:
14 reg.register(Spam)
15
16 assert "Only classes of type SQLAlchemyObjectType can be registered" in str(
17 excinfo.value
18 )
19
20
21 def test_register_objecttype():
11 def test_register_object_type():
2212 reg = Registry()
2313
2414 class PetType(SQLAlchemyObjectType):
2616 model = Pet
2717 registry = reg
2818
29 try:
30 reg.register(PetType)
31 except AssertionError:
32 pytest.fail("expected no AssertionError")
19 reg.register(PetType)
20 assert reg.get_type_for_model(Pet) is PetType
21
22
23 def test_register_incorrect_object_type():
24 reg = Registry()
25
26 class Spam:
27 pass
28
29 re_err = "Expected SQLAlchemyObjectType, but got: .*Spam"
30 with pytest.raises(TypeError, match=re_err):
31 reg.register(Spam)
32
33
34 def test_register_orm_field():
35 reg = Registry()
36
37 class PetType(SQLAlchemyObjectType):
38 class Meta:
39 model = Pet
40 registry = reg
41
42 reg.register_orm_field(PetType, "name", Pet.name)
43 assert reg.get_orm_field_for_graphene_field(PetType, "name") is Pet.name
44
45
46 def test_register_orm_field_incorrect_types():
47 reg = Registry()
48
49 class Spam:
50 pass
51
52 re_err = "Expected SQLAlchemyObjectType, but got: .*Spam"
53 with pytest.raises(TypeError, match=re_err):
54 reg.register_orm_field(Spam, "name", Pet.name)
55
56 class PetType(SQLAlchemyObjectType):
57 class Meta:
58 model = Pet
59 registry = reg
60
61 re_err = "Expected a field name, but got: .*Spam"
62 with pytest.raises(TypeError, match=re_err):
63 reg.register_orm_field(PetType, Spam, Pet.name)
64
65
66 def test_register_enum():
67 reg = Registry()
68
69 sa_enum = SQLAlchemyEnum("cat", "dog")
70 graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)])
71
72 reg.register_enum(sa_enum, graphene_enum)
73 assert reg.get_graphene_enum_for_sa_enum(sa_enum) is graphene_enum
74
75
76 def test_register_enum_incorrect_types():
77 reg = Registry()
78
79 sa_enum = SQLAlchemyEnum("cat", "dog")
80 graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)])
81
82 re_err = r"Expected Graphene Enum, but got: Enum\('cat', 'dog'\)"
83 with pytest.raises(TypeError, match=re_err):
84 reg.register_enum(sa_enum, sa_enum)
85
86 re_err = r"Expected SQLAlchemyEnumType, but got: .*PetKind.*"
87 with pytest.raises(TypeError, match=re_err):
88 reg.register_enum(graphene_enum, graphene_enum)
89
90
91 def test_register_sort_enum():
92 reg = Registry()
93
94 class PetType(SQLAlchemyObjectType):
95 class Meta:
96 model = Pet
97 registry = reg
98
99 sort_enum = GrapheneEnum(
100 "PetSort",
101 [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))],
102 )
103
104 reg.register_sort_enum(PetType, sort_enum)
105 assert reg.get_sort_enum_for_object_type(PetType) is sort_enum
106
107
108 def test_register_sort_enum_incorrect_types():
109 reg = Registry()
110
111 class PetType(SQLAlchemyObjectType):
112 class Meta:
113 model = Pet
114 registry = reg
115
116 sort_enum = GrapheneEnum(
117 "PetSort",
118 [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))],
119 )
120
121 re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*"
122 with pytest.raises(TypeError, match=re_err):
123 reg.register_sort_enum(sort_enum, sort_enum)
124
125 re_err = r"Expected Graphene Enum, but got: .*PetType.*"
126 with pytest.raises(TypeError, match=re_err):
127 reg.register_sort_enum(PetType, PetType)
+0
-49
graphene_sqlalchemy/tests/test_schema.py less more
0 from py.test import raises
1
2 from ..registry import Registry
3 from ..types import SQLAlchemyObjectType
4 from .models import Reporter
5
6
7 def test_should_raise_if_no_model():
8 with raises(Exception) as excinfo:
9
10 class Character1(SQLAlchemyObjectType):
11 pass
12
13 assert "valid SQLAlchemy Model" in str(excinfo.value)
14
15
16 def test_should_raise_if_model_is_invalid():
17 with raises(Exception) as excinfo:
18
19 class Character2(SQLAlchemyObjectType):
20 class Meta:
21 model = 1
22
23 assert "valid SQLAlchemy Model" in str(excinfo.value)
24
25
26 def test_should_map_fields_correctly():
27 class ReporterType2(SQLAlchemyObjectType):
28 class Meta:
29 model = Reporter
30 registry = Registry()
31
32 assert list(ReporterType2._meta.fields.keys()) == [
33 "id",
34 "first_name",
35 "last_name",
36 "email",
37 "pets",
38 "articles",
39 "favorite_article",
40 ]
41
42
43 def test_should_map_only_few_fields():
44 class Reporter2(SQLAlchemyObjectType):
45 class Meta:
46 model = Reporter
47 only_fields = ("id", "email")
48 assert list(Reporter2._meta.fields.keys()) == ["id", "email"]
0 import pytest
1 import sqlalchemy as sa
2
3 from graphene import Argument, Enum, List, ObjectType, Schema
4 from graphene.relay import Node
5
6 from ..fields import SQLAlchemyConnectionField
7 from ..types import SQLAlchemyObjectType
8 from ..utils import to_type_name
9 from .models import Base, HairKind, Pet
10 from .test_query import to_std_dicts
11
12
13 def add_pets(session):
14 pets = [
15 Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG),
16 Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG),
17 Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG),
18 ]
19 session.add_all(pets)
20 session.commit()
21
22
23 def test_sort_enum():
24 class PetType(SQLAlchemyObjectType):
25 class Meta:
26 model = Pet
27
28 sort_enum = PetType.sort_enum()
29 assert isinstance(sort_enum, type(Enum))
30 assert sort_enum._meta.name == "PetTypeSortEnum"
31 assert list(sort_enum._meta.enum.__members__) == [
32 "ID_ASC",
33 "ID_DESC",
34 "NAME_ASC",
35 "NAME_DESC",
36 "PET_KIND_ASC",
37 "PET_KIND_DESC",
38 "HAIR_KIND_ASC",
39 "HAIR_KIND_DESC",
40 "REPORTER_ID_ASC",
41 "REPORTER_ID_DESC",
42 ]
43 assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC"
44 assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC"
45 assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC"
46 assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC"
47
48
49 def test_sort_enum_with_custom_name():
50 class PetType(SQLAlchemyObjectType):
51 class Meta:
52 model = Pet
53
54 sort_enum = PetType.sort_enum(name="CustomSortName")
55 assert isinstance(sort_enum, type(Enum))
56 assert sort_enum._meta.name == "CustomSortName"
57
58
59 def test_sort_enum_cache():
60 class PetType(SQLAlchemyObjectType):
61 class Meta:
62 model = Pet
63
64 sort_enum = PetType.sort_enum()
65 sort_enum_2 = PetType.sort_enum()
66 assert sort_enum_2 is sort_enum
67 sort_enum_2 = PetType.sort_enum(name="PetTypeSortEnum")
68 assert sort_enum_2 is sort_enum
69 err_msg = "Sort enum for PetType has already been customized"
70 with pytest.raises(ValueError, match=err_msg):
71 PetType.sort_enum(name="CustomSortName")
72 with pytest.raises(ValueError, match=err_msg):
73 PetType.sort_enum(only_fields=["id"])
74 with pytest.raises(ValueError, match=err_msg):
75 PetType.sort_enum(only_indexed=True)
76 with pytest.raises(ValueError, match=err_msg):
77 PetType.sort_enum(get_symbol_name=lambda: "foo")
78
79
80 def test_sort_enum_with_excluded_field_in_object_type():
81 class PetType(SQLAlchemyObjectType):
82 class Meta:
83 model = Pet
84 exclude_fields = ["reporter_id"]
85
86 sort_enum = PetType.sort_enum()
87 assert list(sort_enum._meta.enum.__members__) == [
88 "ID_ASC",
89 "ID_DESC",
90 "NAME_ASC",
91 "NAME_DESC",
92 "PET_KIND_ASC",
93 "PET_KIND_DESC",
94 "HAIR_KIND_ASC",
95 "HAIR_KIND_DESC",
96 ]
97
98
99 def test_sort_enum_only_fields():
100 class PetType(SQLAlchemyObjectType):
101 class Meta:
102 model = Pet
103
104 sort_enum = PetType.sort_enum(only_fields=["id", "name"])
105 assert list(sort_enum._meta.enum.__members__) == [
106 "ID_ASC",
107 "ID_DESC",
108 "NAME_ASC",
109 "NAME_DESC",
110 ]
111
112
113 def test_sort_argument():
114 class PetType(SQLAlchemyObjectType):
115 class Meta:
116 model = Pet
117
118 sort_arg = PetType.sort_argument()
119 assert isinstance(sort_arg, Argument)
120
121 assert isinstance(sort_arg.type, List)
122 sort_enum = sort_arg.type._of_type
123 assert isinstance(sort_enum, type(Enum))
124 assert sort_enum._meta.name == "PetTypeSortEnum"
125 assert list(sort_enum._meta.enum.__members__) == [
126 "ID_ASC",
127 "ID_DESC",
128 "NAME_ASC",
129 "NAME_DESC",
130 "PET_KIND_ASC",
131 "PET_KIND_DESC",
132 "HAIR_KIND_ASC",
133 "HAIR_KIND_DESC",
134 "REPORTER_ID_ASC",
135 "REPORTER_ID_DESC",
136 ]
137 assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC"
138 assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC"
139 assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC"
140 assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC"
141
142 assert sort_arg.default_value == ["ID_ASC"]
143 assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC"
144
145
146 def test_sort_argument_with_excluded_fields_in_object_type():
147 class PetType(SQLAlchemyObjectType):
148 class Meta:
149 model = Pet
150 exclude_fields = ["hair_kind", "reporter_id"]
151
152 sort_arg = PetType.sort_argument()
153 sort_enum = sort_arg.type._of_type
154 assert list(sort_enum._meta.enum.__members__) == [
155 "ID_ASC",
156 "ID_DESC",
157 "NAME_ASC",
158 "NAME_DESC",
159 "PET_KIND_ASC",
160 "PET_KIND_DESC",
161 ]
162 assert sort_arg.default_value == ["ID_ASC"]
163
164
165 def test_sort_argument_only_fields():
166 class PetType(SQLAlchemyObjectType):
167 class Meta:
168 model = Pet
169 only_fields = ["id", "pet_kind"]
170
171 sort_arg = PetType.sort_argument()
172 sort_enum = sort_arg.type._of_type
173 assert list(sort_enum._meta.enum.__members__) == [
174 "ID_ASC",
175 "ID_DESC",
176 "PET_KIND_ASC",
177 "PET_KIND_DESC",
178 ]
179 assert sort_arg.default_value == ["ID_ASC"]
180
181
182 def test_sort_argument_for_multi_column_pk():
183 class MultiPkTestModel(Base):
184 __tablename__ = "multi_pk_test_table"
185 foo = sa.Column(sa.Integer, primary_key=True)
186 bar = sa.Column(sa.Integer, primary_key=True)
187
188 class MultiPkTestType(SQLAlchemyObjectType):
189 class Meta:
190 model = MultiPkTestModel
191
192 sort_arg = MultiPkTestType.sort_argument()
193 assert sort_arg.default_value == ["FOO_ASC", "BAR_ASC"]
194
195
196 def test_sort_argument_only_indexed():
197 class IndexedTestModel(Base):
198 __tablename__ = "indexed_test_table"
199 id = sa.Column(sa.Integer, primary_key=True)
200 foo = sa.Column(sa.Integer, index=False)
201 bar = sa.Column(sa.Integer, index=True)
202
203 class IndexedTestType(SQLAlchemyObjectType):
204 class Meta:
205 model = IndexedTestModel
206
207 sort_arg = IndexedTestType.sort_argument(only_indexed=True)
208 sort_enum = sort_arg.type._of_type
209 assert list(sort_enum._meta.enum.__members__) == [
210 "ID_ASC",
211 "ID_DESC",
212 "BAR_ASC",
213 "BAR_DESC",
214 ]
215 assert sort_arg.default_value == ["ID_ASC"]
216
217
218 def test_sort_argument_with_custom_symbol_names():
219 class PetType(SQLAlchemyObjectType):
220 class Meta:
221 model = Pet
222
223 def get_symbol_name(column_name, sort_asc=True):
224 return to_type_name(column_name) + ("Up" if sort_asc else "Down")
225
226 sort_arg = PetType.sort_argument(get_symbol_name=get_symbol_name)
227 sort_enum = sort_arg.type._of_type
228 assert list(sort_enum._meta.enum.__members__) == [
229 "IdUp",
230 "IdDown",
231 "NameUp",
232 "NameDown",
233 "PetKindUp",
234 "PetKindDown",
235 "HairKindUp",
236 "HairKindDown",
237 "ReporterIdUp",
238 "ReporterIdDown",
239 ]
240 assert sort_arg.default_value == ["IdUp"]
241
242
243 def test_sort_query(session):
244 add_pets(session)
245
246 class PetNode(SQLAlchemyObjectType):
247 class Meta:
248 model = Pet
249 interfaces = (Node,)
250
251 class Query(ObjectType):
252 defaultSort = SQLAlchemyConnectionField(PetNode.connection)
253 nameSort = SQLAlchemyConnectionField(PetNode.connection)
254 multipleSort = SQLAlchemyConnectionField(PetNode.connection)
255 descSort = SQLAlchemyConnectionField(PetNode.connection)
256 singleColumnSort = SQLAlchemyConnectionField(
257 PetNode.connection, sort=Argument(PetNode.sort_enum())
258 )
259 noDefaultSort = SQLAlchemyConnectionField(
260 PetNode.connection, sort=PetNode.sort_argument(has_default=False)
261 )
262 noSort = SQLAlchemyConnectionField(PetNode.connection, sort=None)
263
264 query = """
265 query sortTest {
266 defaultSort {
267 edges {
268 node {
269 name
270 }
271 }
272 }
273 nameSort(sort: NAME_ASC) {
274 edges {
275 node {
276 name
277 }
278 }
279 }
280 multipleSort(sort: [PET_KIND_ASC, NAME_DESC]) {
281 edges {
282 node {
283 name
284 petKind
285 }
286 }
287 }
288 descSort(sort: [NAME_DESC]) {
289 edges {
290 node {
291 name
292 }
293 }
294 }
295 singleColumnSort(sort: NAME_DESC) {
296 edges {
297 node {
298 name
299 }
300 }
301 }
302 noDefaultSort(sort: NAME_ASC) {
303 edges {
304 node {
305 name
306 }
307 }
308 }
309 }
310 """
311
312 def makeNodes(nodeList):
313 nodes = [{"node": item} for item in nodeList]
314 return {"edges": nodes}
315
316 expected = {
317 "defaultSort": makeNodes(
318 [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]
319 ),
320 "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]),
321 "noDefaultSort": makeNodes(
322 [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]
323 ),
324 "multipleSort": makeNodes(
325 [
326 {"name": "Alf", "petKind": "CAT"},
327 {"name": "Lassie", "petKind": "DOG"},
328 {"name": "Barf", "petKind": "DOG"},
329 ]
330 ),
331 "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]),
332 "singleColumnSort": makeNodes(
333 [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]
334 ),
335 } # yapf: disable
336
337 schema = Schema(query=Query)
338 result = schema.execute(query, context_value={"session": session})
339 assert not result.errors
340 result = to_std_dicts(result.data)
341 assert result == expected
342
343 queryError = """
344 query sortTest {
345 singleColumnSort(sort: [PET_KIND_ASC, NAME_DESC]) {
346 edges {
347 node {
348 name
349 }
350 }
351 }
352 }
353 """
354 result = schema.execute(queryError, context_value={"session": session})
355 assert result.errors is not None
356 assert '"sort" has invalid value' in result.errors[0].message
357
358 queryNoSort = """
359 query sortTest {
360 noDefaultSort {
361 edges {
362 node {
363 name
364 }
365 }
366 }
367 noSort {
368 edges {
369 node {
370 name
371 }
372 }
373 }
374 }
375 """
376
377 result = schema.execute(queryNoSort, context_value={"session": session})
378 assert not result.errors
379 # TODO: SQLite usually returns the results ordered by primary key,
380 # so we cannot test this way whether sorting actually happens or not.
381 # Also, no sort order is guaranteed by SQLite if "no order" by is used.
382 assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [
383 node["node"]["name"] for node in result.data["noDefaultSort"]["edges"]
384 ]
0 from collections import OrderedDict
1
0 import mock
1 import pytest
22 import six # noqa F401
3 from promise import Promise
4
5 from graphene import (Connection, Field, Int, Interface, Node, ObjectType,
6 is_node)
7
3
4 from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull,
5 ObjectType, Schema, String)
6 from graphene.relay import Connection
7
8 from ..converter import convert_sqlalchemy_composite
89 from ..fields import (SQLAlchemyConnectionField,
9 UnsortedSQLAlchemyConnectionField,
10 UnsortedSQLAlchemyConnectionField, createConnectionField,
1011 registerConnectionFieldFactory,
1112 unregisterConnectionFieldFactory)
12 from ..registry import Registry
13 from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
14 from .models import Article, Reporter
15
16 registry = Registry()
17
18
19 class Character(SQLAlchemyObjectType):
20 """Character description"""
21
22 class Meta:
23 model = Reporter
24 registry = registry
25
26
27 class Human(SQLAlchemyObjectType):
28 """Human description"""
29
30 pub_date = Int()
31
32 class Meta:
33 model = Article
34 exclude_fields = ("id",)
35 registry = registry
36 interfaces = (Node,)
37
38
39 def test_sqlalchemy_interface():
40 assert issubclass(Node, Interface)
41 assert issubclass(Node, Node)
42
43
44 # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
45 # def test_sqlalchemy_get_node(get):
46 # human = Human.get_node(1, None)
47 # get.assert_called_with(id=1)
48 # assert human.id == 1
49
50
51 def test_objecttype_registered():
52 assert issubclass(Character, ObjectType)
53 assert Character._meta.model == Reporter
54 assert list(Character._meta.fields.keys()) == [
13 from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
14 from .models import Article, CompositeFullName, Pet, Reporter
15
16
17 def test_should_raise_if_no_model():
18 re_err = r"valid SQLAlchemy Model"
19 with pytest.raises(Exception, match=re_err):
20 class Character1(SQLAlchemyObjectType):
21 pass
22
23
24 def test_should_raise_if_model_is_invalid():
25 re_err = r"valid SQLAlchemy Model"
26 with pytest.raises(Exception, match=re_err):
27 class Character(SQLAlchemyObjectType):
28 class Meta:
29 model = 1
30
31
32 def test_sqlalchemy_node(session):
33 class ReporterType(SQLAlchemyObjectType):
34 class Meta:
35 model = Reporter
36 interfaces = (Node,)
37
38 reporter_id_field = ReporterType._meta.fields["id"]
39 assert isinstance(reporter_id_field, GlobalID)
40
41 reporter = Reporter()
42 session.add(reporter)
43 session.commit()
44 info = mock.Mock(context={'session': session})
45 reporter_node = ReporterType.get_node(info, reporter.id)
46 assert reporter == reporter_node
47
48
49 def test_connection():
50 class ReporterType(SQLAlchemyObjectType):
51 class Meta:
52 model = Reporter
53 interfaces = (Node,)
54
55 assert issubclass(ReporterType.connection, Connection)
56
57
58 def test_sqlalchemy_default_fields():
59 @convert_sqlalchemy_composite.register(CompositeFullName)
60 def convert_composite_class(composite, registry):
61 return String()
62
63 class ReporterType(SQLAlchemyObjectType):
64 class Meta:
65 model = Reporter
66 interfaces = (Node,)
67
68 class ArticleType(SQLAlchemyObjectType):
69 class Meta:
70 model = Article
71 interfaces = (Node,)
72
73 assert list(ReporterType._meta.fields.keys()) == [
74 # Columns
75 "column_prop", # SQLAlchemy retuns column properties first
5576 "id",
5677 "first_name",
5778 "last_name",
5879 "email",
80 "favorite_pet_kind",
81 # Composite
82 "composite_prop",
83 # Hybrid
84 "hybrid_prop",
85 # Relationship
5986 "pets",
6087 "articles",
6188 "favorite_article",
6289 ]
6390
64
65 # def test_sqlalchemynode_idfield():
66 # idfield = Node._meta.fields_map['id']
67 # assert isinstance(idfield, GlobalIDField)
68
69
70 # def test_node_idfield():
71 # idfield = Human._meta.fields_map['id']
72 # assert isinstance(idfield, GlobalIDField)
73
74
75 def test_node_replacedfield():
76 idfield = Human._meta.fields["pub_date"]
77 assert isinstance(idfield, Field)
78 assert idfield.type == Int
79
80
81 def test_object_type():
82 class Human(SQLAlchemyObjectType):
83 """Human description"""
84
85 pub_date = Int()
86
91 # column
92 first_name_field = ReporterType._meta.fields['first_name']
93 assert first_name_field.type == String
94 assert first_name_field.description == "First name"
95
96 # column_property
97 column_prop_field = ReporterType._meta.fields['column_prop']
98 assert column_prop_field.type == Int
99 # "doc" is ignored by column_property
100 assert column_prop_field.description is None
101
102 # composite
103 full_name_field = ReporterType._meta.fields['composite_prop']
104 assert full_name_field.type == String
105 # "doc" is ignored by composite
106 assert full_name_field.description is None
107
108 # hybrid_property
109 hybrid_prop = ReporterType._meta.fields['hybrid_prop']
110 assert hybrid_prop.type == String
111 # "doc" is ignored by hybrid_property
112 assert hybrid_prop.description is None
113
114 # relationship
115 favorite_article_field = ReporterType._meta.fields['favorite_article']
116 assert isinstance(favorite_article_field, Dynamic)
117 assert favorite_article_field.type().type == ArticleType
118 assert favorite_article_field.type().description is None
119
120
121 def test_sqlalchemy_override_fields():
122 @convert_sqlalchemy_composite.register(CompositeFullName)
123 def convert_composite_class(composite, registry):
124 return String()
125
126 class ReporterMixin(object):
127 # columns
128 first_name = ORMField(required=True)
129 last_name = ORMField(description='Overridden')
130
131 class ReporterType(SQLAlchemyObjectType, ReporterMixin):
132 class Meta:
133 model = Reporter
134 interfaces = (Node,)
135
136 # columns
137 email = ORMField(deprecation_reason='Overridden')
138 email_v2 = ORMField(model_attr='email', type=Int)
139
140 # column_property
141 column_prop = ORMField(type=String)
142
143 # composite
144 composite_prop = ORMField()
145
146 # hybrid_property
147 hybrid_prop = ORMField(description='Overridden')
148
149 # relationships
150 favorite_article = ORMField(description='Overridden')
151 articles = ORMField(deprecation_reason='Overridden')
152 pets = ORMField(description='Overridden')
153
154 class ArticleType(SQLAlchemyObjectType):
87155 class Meta:
88156 model = Article
89 # exclude_fields = ('id', )
90 registry = registry
91 interfaces = (Node,)
92
93 assert issubclass(Human, ObjectType)
94 assert list(Human._meta.fields.keys()) == [
95 "id",
96 "headline",
97 "pub_date",
98 "reporter_id",
99 "reporter",
100 ]
101 assert is_node(Human)
102
103
104 # Test Custom SQLAlchemyObjectType Implementation
105 class CustomSQLAlchemyObjectType(SQLAlchemyObjectType):
106 class Meta:
107 abstract = True
108
109
110 class CustomCharacter(CustomSQLAlchemyObjectType):
111 """Character description"""
112
113 class Meta:
114 model = Reporter
115 registry = registry
116
117
118 def test_custom_objecttype_registered():
119 assert issubclass(CustomCharacter, ObjectType)
120 assert CustomCharacter._meta.model == Reporter
121 assert list(CustomCharacter._meta.fields.keys()) == [
122 "id",
157 interfaces = (Node,)
158
159 class PetType(SQLAlchemyObjectType):
160 class Meta:
161 model = Pet
162 interfaces = (Node,)
163 use_connection = False
164
165 assert list(ReporterType._meta.fields.keys()) == [
166 # Fields from ReporterMixin
123167 "first_name",
124168 "last_name",
169 # Fields from ReporterType
125170 "email",
171 "email_v2",
172 "column_prop",
173 "composite_prop",
174 "hybrid_prop",
175 "favorite_article",
176 "articles",
177 "pets",
178 # Then the automatic SQLAlchemy fields
179 "id",
180 "favorite_pet_kind",
181 ]
182
183 first_name_field = ReporterType._meta.fields['first_name']
184 assert isinstance(first_name_field.type, NonNull)
185 assert first_name_field.type.of_type == String
186 assert first_name_field.description == "First name"
187 assert first_name_field.deprecation_reason is None
188
189 last_name_field = ReporterType._meta.fields['last_name']
190 assert last_name_field.type == String
191 assert last_name_field.description == "Overridden"
192 assert last_name_field.deprecation_reason is None
193
194 email_field = ReporterType._meta.fields['email']
195 assert email_field.type == String
196 assert email_field.description == "Email"
197 assert email_field.deprecation_reason == "Overridden"
198
199 email_field_v2 = ReporterType._meta.fields['email_v2']
200 assert email_field_v2.type == Int
201 assert email_field_v2.description == "Email"
202 assert email_field_v2.deprecation_reason is None
203
204 hybrid_prop_field = ReporterType._meta.fields['hybrid_prop']
205 assert hybrid_prop_field.type == String
206 assert hybrid_prop_field.description == "Overridden"
207 assert hybrid_prop_field.deprecation_reason is None
208
209 column_prop_field_v2 = ReporterType._meta.fields['column_prop']
210 assert column_prop_field_v2.type == String
211 assert column_prop_field_v2.description is None
212 assert column_prop_field_v2.deprecation_reason is None
213
214 composite_prop_field = ReporterType._meta.fields['composite_prop']
215 assert composite_prop_field.type == String
216 assert composite_prop_field.description is None
217 assert composite_prop_field.deprecation_reason is None
218
219 favorite_article_field = ReporterType._meta.fields['favorite_article']
220 assert isinstance(favorite_article_field, Dynamic)
221 assert favorite_article_field.type().type == ArticleType
222 assert favorite_article_field.type().description == 'Overridden'
223
224 articles_field = ReporterType._meta.fields['articles']
225 assert isinstance(articles_field, Dynamic)
226 assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField)
227 assert articles_field.type().deprecation_reason == "Overridden"
228
229 pets_field = ReporterType._meta.fields['pets']
230 assert isinstance(pets_field, Dynamic)
231 assert isinstance(pets_field.type().type, List)
232 assert pets_field.type().type.of_type == PetType
233 assert pets_field.type().description == 'Overridden'
234
235
236 def test_invalid_model_attr():
237 err_msg = (
238 "Cannot map ORMField to a model attribute.\n"
239 "Field: 'ReporterType.first_name'"
240 )
241 with pytest.raises(ValueError, match=err_msg):
242 class ReporterType(SQLAlchemyObjectType):
243 class Meta:
244 model = Reporter
245
246 first_name = ORMField(model_attr='does_not_exist')
247
248
249 def test_only_fields():
250 class ReporterType(SQLAlchemyObjectType):
251 class Meta:
252 model = Reporter
253 only_fields = ("id", "last_name")
254
255 first_name = ORMField() # Takes precedence
256 last_name = ORMField() # Noop
257
258 assert list(ReporterType._meta.fields.keys()) == ["first_name", "last_name", "id"]
259
260
261 def test_exclude_fields():
262 class ReporterType(SQLAlchemyObjectType):
263 class Meta:
264 model = Reporter
265 exclude_fields = ("id", "first_name")
266
267 first_name = ORMField() # Takes precedence
268 last_name = ORMField() # Noop
269
270 assert list(ReporterType._meta.fields.keys()) == [
271 "first_name",
272 "last_name",
273 "column_prop",
274 "email",
275 "favorite_pet_kind",
276 "composite_prop",
277 "hybrid_prop",
126278 "pets",
127279 "articles",
128280 "favorite_article",
129281 ]
130282
131283
284 def test_only_and_exclude_fields():
285 re_err = r"'only_fields' and 'exclude_fields' cannot be both set"
286 with pytest.raises(Exception, match=re_err):
287 class ReporterType(SQLAlchemyObjectType):
288 class Meta:
289 model = Reporter
290 only_fields = ("id", "last_name")
291 exclude_fields = ("id", "last_name")
292
293
294 def test_sqlalchemy_redefine_field():
295 class ReporterType(SQLAlchemyObjectType):
296 class Meta:
297 model = Reporter
298
299 first_name = Int()
300
301 first_name_field = ReporterType._meta.fields["first_name"]
302 assert isinstance(first_name_field, Field)
303 assert first_name_field.type == Int
304
305
306 def test_resolvers(session):
307 """Test that the correct resolver functions are called"""
308
309 class ReporterMixin(object):
310 def resolve_id(root, _info):
311 return 'ID'
312
313 class ReporterType(ReporterMixin, SQLAlchemyObjectType):
314 class Meta:
315 model = Reporter
316
317 email = ORMField()
318 email_v2 = ORMField(model_attr='email')
319 favorite_pet_kind = Field(String)
320 favorite_pet_kind_v2 = Field(String)
321
322 def resolve_last_name(root, _info):
323 return root.last_name.upper()
324
325 def resolve_email_v2(root, _info):
326 return root.email + '_V2'
327
328 def resolve_favorite_pet_kind_v2(root, _info):
329 return str(root.favorite_pet_kind) + '_V2'
330
331 class Query(ObjectType):
332 reporter = Field(ReporterType)
333
334 def resolve_reporter(self, _info):
335 return session.query(Reporter).first()
336
337 reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat')
338 session.add(reporter)
339 session.commit()
340
341 schema = Schema(query=Query)
342 result = schema.execute("""
343 query {
344 reporter {
345 id
346 firstName
347 lastName
348 email
349 emailV2
350 favoritePetKind
351 favoritePetKindV2
352 }
353 }
354 """)
355
356 assert not result.errors
357 # Custom resolver on a base class
358 assert result.data['reporter']['id'] == 'ID'
359 # Default field + default resolver
360 assert result.data['reporter']['firstName'] == 'first_name'
361 # Default field + custom resolver
362 assert result.data['reporter']['lastName'] == 'LAST_NAME'
363 # ORMField + default resolver
364 assert result.data['reporter']['email'] == 'email'
365 # ORMField + custom resolver
366 assert result.data['reporter']['emailV2'] == 'email_V2'
367 # Field + default resolver
368 assert result.data['reporter']['favoritePetKind'] == 'cat'
369 # Field + custom resolver
370 assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2'
371
372
373 # Test Custom SQLAlchemyObjectType Implementation
374
375 def test_custom_objecttype_registered():
376 class CustomSQLAlchemyObjectType(SQLAlchemyObjectType):
377 class Meta:
378 abstract = True
379
380 class CustomReporterType(CustomSQLAlchemyObjectType):
381 class Meta:
382 model = Reporter
383
384 assert issubclass(CustomReporterType, ObjectType)
385 assert CustomReporterType._meta.model == Reporter
386 assert len(CustomReporterType._meta.fields) == 11
387
388
132389 # Test Custom SQLAlchemyObjectType with Custom Options
133 class CustomOptions(SQLAlchemyObjectTypeOptions):
134 custom_option = None
135 custom_fields = None
136
137
138 class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType):
139 class Meta:
140 abstract = True
141
142 @classmethod
143 def __init_subclass_with_meta__(
144 cls, custom_option=None, custom_fields=None, **options
145 ):
146 _meta = CustomOptions(cls)
147 _meta.custom_option = custom_option
148 _meta.fields = custom_fields
149 super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__(
150 _meta=_meta, **options
151 )
152
153
154 class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions):
155 class Meta:
156 model = Reporter
157 custom_option = "custom_option"
158 custom_fields = OrderedDict([("custom_field", Field(Int()))])
159
160
161390 def test_objecttype_with_custom_options():
391 class CustomOptions(SQLAlchemyObjectTypeOptions):
392 custom_option = None
393
394 class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType):
395 class Meta:
396 abstract = True
397
398 @classmethod
399 def __init_subclass_with_meta__(cls, custom_option=None, **options):
400 _meta = CustomOptions(cls)
401 _meta.custom_option = custom_option
402 super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__(
403 _meta=_meta, **options
404 )
405
406 class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions):
407 class Meta:
408 model = Reporter
409 custom_option = "custom_option"
410
162411 assert issubclass(ReporterWithCustomOptions, ObjectType)
163412 assert ReporterWithCustomOptions._meta.model == Reporter
164 assert list(ReporterWithCustomOptions._meta.fields.keys()) == [
165 "custom_field",
166 "id",
167 "first_name",
168 "last_name",
169 "email",
170 "pets",
171 "articles",
172 "favorite_article",
173 ]
174413 assert ReporterWithCustomOptions._meta.custom_option == "custom_option"
175 assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int)
176
177
178 def test_promise_connection_resolver():
179 class TestConnection(Connection):
180 class Meta:
181 node = ReporterWithCustomOptions
182
183 def resolver(*args, **kwargs):
184 return Promise.resolve([])
185
186 result = SQLAlchemyConnectionField.connection_resolver(
187 resolver, TestConnection, ReporterWithCustomOptions, None, None
188 )
189 assert result is not None
190414
191415
192416 # Tests for connection_field_factory
196420
197421
198422 def test_default_connection_field_factory():
199 _registry = Registry()
200
201 class ReporterType(SQLAlchemyObjectType):
202 class Meta:
203 model = Reporter
204 registry = _registry
423 class ReporterType(SQLAlchemyObjectType):
424 class Meta:
425 model = Reporter
205426 interfaces = (Node,)
206427
207428 class ArticleType(SQLAlchemyObjectType):
208429 class Meta:
209430 model = Article
210 registry = _registry
211431 interfaces = (Node,)
212432
213433 assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField)
214434
215435
216 def test_register_connection_field_factory():
436 def test_custom_connection_field_factory():
217437 def test_connection_field_factory(relationship, registry):
218438 model = relationship.mapper.entity
219439 _type = registry.get_type_for_model(model)
220440 return _TestSQLAlchemyConnectionField(_type._meta.connection)
221441
222 _registry = Registry()
223
224 class ReporterType(SQLAlchemyObjectType):
225 class Meta:
226 model = Reporter
227 registry = _registry
442 class ReporterType(SQLAlchemyObjectType):
443 class Meta:
444 model = Reporter
228445 interfaces = (Node,)
229446 connection_field_factory = test_connection_field_factory
230447
231448 class ArticleType(SQLAlchemyObjectType):
232449 class Meta:
233450 model = Article
234 registry = _registry
235451 interfaces = (Node,)
236452
237453 assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
238454
239455
240456 def test_deprecated_registerConnectionFieldFactory():
241 registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
242
243 _registry = Registry()
244
245 class ReporterType(SQLAlchemyObjectType):
246 class Meta:
247 model = Reporter
248 registry = _registry
249 interfaces = (Node,)
250
251 class ArticleType(SQLAlchemyObjectType):
252 class Meta:
253 model = Article
254 registry = _registry
255 interfaces = (Node,)
256
257 assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
457 with pytest.warns(DeprecationWarning):
458 registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
459
460 class ReporterType(SQLAlchemyObjectType):
461 class Meta:
462 model = Reporter
463 interfaces = (Node,)
464
465 class ArticleType(SQLAlchemyObjectType):
466 class Meta:
467 model = Article
468 interfaces = (Node,)
469
470 assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
258471
259472
260473 def test_deprecated_unregisterConnectionFieldFactory():
261 registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
262 unregisterConnectionFieldFactory()
263
264 _registry = Registry()
265
266 class ReporterType(SQLAlchemyObjectType):
267 class Meta:
268 model = Reporter
269 registry = _registry
270 interfaces = (Node,)
271
272 class ArticleType(SQLAlchemyObjectType):
273 class Meta:
274 model = Article
275 registry = _registry
276 interfaces = (Node,)
277
278 assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
474 with pytest.warns(DeprecationWarning):
475 registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
476 unregisterConnectionFieldFactory()
477
478 class ReporterType(SQLAlchemyObjectType):
479 class Meta:
480 model = Reporter
481 interfaces = (Node,)
482
483 class ArticleType(SQLAlchemyObjectType):
484 class Meta:
485 model = Article
486 interfaces = (Node,)
487
488 assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
489
490
491 def test_deprecated_createConnectionField():
492 with pytest.warns(DeprecationWarning):
493 createConnectionField(None)
0 import pytest
01 import sqlalchemy as sa
12
23 from graphene import Enum, List, ObjectType, Schema, String
34
4 from ..utils import get_session, sort_argument_for_model, sort_enum_for_model
5 from .models import Editor, Pet
5 from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model,
6 to_enum_value_name, to_type_name)
7 from .models import Base, Editor, Pet
68
79
810 def test_get_session():
2628 assert result.data["x"] == session
2729
2830
31 def test_to_type_name():
32 assert to_type_name("make_camel_case") == "MakeCamelCase"
33 assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase"
34 assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel"
35
36
37 def test_to_enum_value_name():
38 assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME"
39 assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME"
40 assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE"
41 assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME"
42
43
44 # test deprecated sort enum utility functions
45
46
2947 def test_sort_enum_for_model():
30 enum = sort_enum_for_model(Pet)
48 with pytest.warns(DeprecationWarning):
49 enum = sort_enum_for_model(Pet)
3150 assert isinstance(enum, type(Enum))
3251 assert str(enum) == "PetSortEnum"
3352 for col in sa.inspect(Pet).columns:
3655
3756
3857 def test_sort_enum_for_model_custom_naming():
39 enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D"))
58 with pytest.warns(DeprecationWarning):
59 enum = sort_enum_for_model(
60 Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")
61 )
4062 assert str(enum) == "Foo"
4163 for col in sa.inspect(Pet).columns:
4264 assert hasattr(enum, col.name.upper() + "A")
4466
4567
4668 def test_enum_cache():
47 assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor)
69 with pytest.warns(DeprecationWarning):
70 assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor)
4871
4972
5073 def test_sort_argument_for_model():
51 arg = sort_argument_for_model(Pet)
74 with pytest.warns(DeprecationWarning):
75 arg = sort_argument_for_model(Pet)
5276
5377 assert isinstance(arg.type, List)
5478 assert arg.default_value == [Pet.id.name + "_asc"]
55 assert arg.type.of_type == sort_enum_for_model(Pet)
79 with pytest.warns(DeprecationWarning):
80 assert arg.type.of_type is sort_enum_for_model(Pet)
5681
5782
5883 def test_sort_argument_for_model_no_default():
59 arg = sort_argument_for_model(Pet, False)
84 with pytest.warns(DeprecationWarning):
85 arg = sort_argument_for_model(Pet, False)
6086
6187 assert arg.default_value is None
6288
6389
6490 def test_sort_argument_for_model_multiple_pk():
65 Base = sa.ext.declarative.declarative_base()
66
6791 class MultiplePK(Base):
6892 foo = sa.Column(sa.Integer, primary_key=True)
6993 bar = sa.Column(sa.Integer, primary_key=True)
7094 __tablename__ = "MultiplePK"
7195
72 arg = sort_argument_for_model(MultiplePK)
96 with pytest.warns(DeprecationWarning):
97 arg = sort_argument_for_model(MultiplePK)
7398 assert set(arg.default_value) == set(
7499 (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc")
75100 )
0 import pkg_resources
1
2
3 def to_std_dicts(value):
4 """Convert nested ordered dicts to normal dicts for better comparison."""
5 if isinstance(value, dict):
6 return {k: to_std_dicts(v) for k, v in value.items()}
7 elif isinstance(value, list):
8 return [to_std_dicts(v) for v in value]
9 else:
10 return value
11
12
13 def is_sqlalchemy_version_less_than(version_string):
14 """Check the installed SQLAlchemy version"""
15 return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string)
11
22 import sqlalchemy
33 from sqlalchemy.ext.hybrid import hybrid_property
4 from sqlalchemy.inspection import inspect as sqlalchemyinspect
4 from sqlalchemy.orm import (ColumnProperty, CompositeProperty,
5 RelationshipProperty)
56 from sqlalchemy.orm.exc import NoResultFound
67
7 from graphene import Field # , annotate, ResolveInfo
8 from graphene import Field
89 from graphene.relay import Connection, Node
910 from graphene.types.objecttype import ObjectType, ObjectTypeOptions
1011 from graphene.types.utils import yank_fields_from_attrs
12 from graphene.utils.orderedtype import OrderedType
1113
1214 from .converter import (convert_sqlalchemy_column,
1315 convert_sqlalchemy_composite,
1416 convert_sqlalchemy_hybrid_method,
1517 convert_sqlalchemy_relationship)
16 from .fields import default_connection_field_factory
18 from .enums import (enum_for_field, sort_argument_for_object_type,
19 sort_enum_for_object_type)
1720 from .registry import Registry, get_global_registry
21 from .resolvers import get_attr_resolver, get_custom_resolver
1822 from .utils import get_query, is_mapped_class, is_mapped_instance
1923
2024
21 def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory):
22 inspected_model = sqlalchemyinspect(model)
23
25 class ORMField(OrderedType):
26 def __init__(
27 self,
28 model_attr=None,
29 type=None,
30 required=None,
31 description=None,
32 deprecation_reason=None,
33 batching=None,
34 _creation_counter=None,
35 **field_kwargs
36 ):
37 """
38 Use this to override fields automatically generated by SQLAlchemyObjectType.
39 Unless specified, options will default to SQLAlchemyObjectType usual behavior
40 for the given SQLAlchemy model property.
41
42 Usage:
43 class MyModel(Base):
44 id = Column(Integer(), primary_key=True)
45 name = Column(String)
46
47 class MyType(SQLAlchemyObjectType):
48 class Meta:
49 model = MyModel
50
51 id = ORMField(type=graphene.Int)
52 name = ORMField(required=True)
53
54 -> MyType.id will be of type Int (vs ID).
55 -> MyType.name will be of type NonNull(String) (vs String).
56
57 :param str model_attr:
58 Name of the SQLAlchemy model attribute used to resolve this field.
59 Default to the name of the attribute referencing the ORMField.
60 :param type:
61 Default to the type mapping in converter.py.
62 :param str description:
63 Default to the `doc` attribute of the SQLAlchemy column property.
64 :param bool required:
65 Default to the opposite of the `nullable` attribute of the SQLAlchemy column property.
66 :param str description:
67 Same behavior as in graphene.Field. Defaults to None.
68 :param str deprecation_reason:
69 Same behavior as in graphene.Field. Defaults to None.
70 :param bool batching:
71 Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`.
72 :param int _creation_counter:
73 Same behavior as in graphene.Field.
74 """
75 super(ORMField, self).__init__(_creation_counter=_creation_counter)
76 # The is only useful for documentation and auto-completion
77 common_kwargs = {
78 'model_attr': model_attr,
79 'type': type,
80 'required': required,
81 'description': description,
82 'deprecation_reason': deprecation_reason,
83 'batching': batching,
84 }
85 common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None}
86 self.kwargs = field_kwargs
87 self.kwargs.update(common_kwargs)
88
89
90 def construct_fields(
91 obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory
92 ):
93 """
94 Construct all the fields for a SQLAlchemyObjectType.
95 The main steps are:
96 - Gather all the relevant attributes from the SQLAlchemy model
97 - Gather all the ORM fields defined on the type
98 - Merge in overrides and build up all the fields
99
100 :param SQLAlchemyObjectType obj_type:
101 :param model: the SQLAlchemy model
102 :param Registry registry:
103 :param tuple[string] only_fields:
104 :param tuple[string] exclude_fields:
105 :param bool batching:
106 :param function|None connection_field_factory:
107 :rtype: OrderedDict[str, graphene.Field]
108 """
109 inspected_model = sqlalchemy.inspect(model)
110 # Gather all the relevant attributes from the SQLAlchemy model in order
111 all_model_attrs = OrderedDict(
112 inspected_model.column_attrs.items() +
113 inspected_model.composites.items() +
114 [(name, item) for name, item in inspected_model.all_orm_descriptors.items()
115 if isinstance(item, hybrid_property)] +
116 inspected_model.relationships.items()
117 )
118
119 # Filter out excluded fields
120 auto_orm_field_names = []
121 for attr_name, attr in all_model_attrs.items():
122 if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields):
123 continue
124 auto_orm_field_names.append(attr_name)
125
126 # Gather all the ORM fields defined on the type
127 custom_orm_fields_items = [
128 (attn_name, attr)
129 for base in reversed(obj_type.__mro__)
130 for attn_name, attr in base.__dict__.items()
131 if isinstance(attr, ORMField)
132 ]
133 custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1])
134
135 # Set the model_attr if not set
136 for orm_field_name, orm_field in custom_orm_fields_items:
137 attr_name = orm_field.kwargs.get('model_attr', orm_field_name)
138 if attr_name not in all_model_attrs:
139 raise ValueError((
140 "Cannot map ORMField to a model attribute.\n"
141 "Field: '{}.{}'"
142 ).format(obj_type.__name__, orm_field_name,))
143 orm_field.kwargs['model_attr'] = attr_name
144
145 # Merge automatic fields with custom ORM fields
146 orm_fields = OrderedDict(custom_orm_fields_items)
147 for orm_field_name in auto_orm_field_names:
148 if orm_field_name in orm_fields:
149 continue
150 orm_fields[orm_field_name] = ORMField(model_attr=orm_field_name)
151
152 # Build all the field dictionary
24153 fields = OrderedDict()
25
26 for name, column in inspected_model.columns.items():
27 is_not_in_only = only_fields and name not in only_fields
28 # is_already_created = name in options.fields
29 is_excluded = name in exclude_fields # or is_already_created
30 if is_not_in_only or is_excluded:
31 # We skip this field if we specify only_fields and is not
32 # in there. Or when we exclude this field in exclude_fields
33 continue
34 converted_column = convert_sqlalchemy_column(column, registry)
35 fields[name] = converted_column
36
37 for name, composite in inspected_model.composites.items():
38 is_not_in_only = only_fields and name not in only_fields
39 # is_already_created = name in options.fields
40 is_excluded = name in exclude_fields # or is_already_created
41 if is_not_in_only or is_excluded:
42 # We skip this field if we specify only_fields and is not
43 # in there. Or when we exclude this field in exclude_fields
44 continue
45 converted_composite = convert_sqlalchemy_composite(composite, registry)
46 fields[name] = converted_composite
47
48 for hybrid_item in inspected_model.all_orm_descriptors:
49
50 if type(hybrid_item) == hybrid_property:
51 name = hybrid_item.__name__
52
53 is_not_in_only = only_fields and name not in only_fields
54 # is_already_created = name in options.fields
55 is_excluded = name in exclude_fields # or is_already_created
56
57 if is_not_in_only or is_excluded:
58 # We skip this field if we specify only_fields and is not
59 # in there. Or when we exclude this field in exclude_fields
60 continue
61
62 converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item)
63 fields[name] = converted_hybrid_property
64
65 # Get all the columns for the relationships on the model
66 for relationship in inspected_model.relationships:
67 is_not_in_only = only_fields and relationship.key not in only_fields
68 # is_already_created = relationship.key in options.fields
69 is_excluded = relationship.key in exclude_fields # or is_already_created
70 if is_not_in_only or is_excluded:
71 # We skip this field if we specify only_fields and is not
72 # in there. Or when we exclude this field in exclude_fields
73 continue
74 converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory)
75 name = relationship.key
76 fields[name] = converted_relationship
154 for orm_field_name, orm_field in orm_fields.items():
155 attr_name = orm_field.kwargs.pop('model_attr')
156 attr = all_model_attrs[attr_name]
157 resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name)
158
159 if isinstance(attr, ColumnProperty):
160 field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs)
161 elif isinstance(attr, RelationshipProperty):
162 batching_ = orm_field.kwargs.pop('batching', batching)
163 field = convert_sqlalchemy_relationship(
164 attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs)
165 elif isinstance(attr, CompositeProperty):
166 if attr_name != orm_field_name or orm_field.kwargs:
167 # TODO Add a way to override composite property fields
168 raise ValueError(
169 "ORMField kwargs for composite fields must be empty. "
170 "Field: {}.{}".format(obj_type.__name__, orm_field_name))
171 field = convert_sqlalchemy_composite(attr, registry, resolver)
172 elif isinstance(attr, hybrid_property):
173 field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs)
174 else:
175 raise Exception('Property type is not supported') # Should never happen
176
177 registry.register_orm_field(obj_type, orm_field_name, attr)
178 fields[orm_field_name] = field
77179
78180 return fields
79181
99201 use_connection=None,
100202 interfaces=(),
101203 id=None,
102 connection_field_factory=default_connection_field_factory,
204 batching=False,
205 connection_field_factory=None,
103206 _meta=None,
104207 **options
105208 ):
115218 'Registry, received "{}".'
116219 ).format(cls.__name__, registry)
117220
221 if only_fields and exclude_fields:
222 raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.")
223
118224 sqla_fields = yank_fields_from_attrs(
119225 construct_fields(
226 obj_type=cls,
120227 model=model,
121228 registry=registry,
122229 only_fields=only_fields,
123230 exclude_fields=exclude_fields,
124 connection_field_factory=connection_field_factory
231 batching=batching,
232 connection_field_factory=connection_field_factory,
125233 ),
126 _as=Field
234 _as=Field,
235 sort=False,
127236 )
128237
129238 if use_connection is None and interfaces:
158267
159268 _meta.connection = connection
160269 _meta.id = id or "id"
270
271 cls.connection = connection # Public way to get the connection
161272
162273 super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__(
163274 _meta=_meta, interfaces=interfaces, **options
190301 # graphene_type = info.parent_type.graphene_type
191302 keys = self.__mapper__.primary_key_from_instance(self)
192303 return tuple(keys) if len(keys) > 1 else keys[0]
304
305 @classmethod
306 def enum_for_field(cls, field_name):
307 return enum_for_field(cls, field_name)
308
309 sort_enum = classmethod(sort_enum_for_object_type)
310
311 sort_argument = classmethod(sort_argument_for_object_type)
0 import re
1 import warnings
2
03 from sqlalchemy.exc import ArgumentError
1 from sqlalchemy.inspection import inspect
24 from sqlalchemy.orm import class_mapper, object_mapper
35 from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError
4
5 from graphene import Argument, Enum, List
66
77
88 def get_session(context):
4040 return True
4141
4242
43 def _symbol_name(column_name, is_asc):
44 return column_name + ("_asc" if is_asc else "_desc")
43 def to_type_name(name):
44 """Convert the given name to a GraphQL type name."""
45 return "".join(part[:1].upper() + part[1:] for part in name.split("_"))
46
47
48 _re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)")
49 _re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])")
50
51
52 def to_enum_value_name(name):
53 """Convert the given name to a GraphQL enum value name."""
54 return _re_enum_value_name_2.sub(
55 r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name)
56 ).upper()
4557
4658
4759 class EnumValue(str):
48 """Subclass of str that stores a string and an arbitrary value in the "value" property"""
60 """String that has an additional value attached.
4961
50 def __new__(cls, str_value, value):
51 return super(EnumValue, cls).__new__(cls, str_value)
62 This is used to attach SQLAlchemy model columns to Enum symbols.
63 """
5264
53 def __init__(self, str_value, value):
65 def __new__(cls, s, value):
66 return super(EnumValue, cls).__new__(cls, s)
67
68 def __init__(self, _s, value):
5469 super(EnumValue, self).__init__()
5570 self.value = value
5671
5772
58 # Cache for the generated enums, to avoid name clash
59 _ENUM_CACHE = {}
73 def _deprecated_default_symbol_name(column_name, sort_asc):
74 return column_name + ("_asc" if sort_asc else "_desc")
6075
6176
62 def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
63 name = name or cls.__name__ + "SortEnum"
64 if name in _ENUM_CACHE:
65 return _ENUM_CACHE[name]
66 items = []
67 default = []
68 for column in inspect(cls).columns.values():
69 asc_name = symbol_name(column.name, True)
70 asc_value = EnumValue(asc_name, column.asc())
71 desc_name = symbol_name(column.name, False)
72 desc_value = EnumValue(desc_name, column.desc())
73 if column.primary_key:
74 default.append(asc_value)
75 items.extend(((asc_name, asc_value), (desc_name, desc_value)))
76 enum = Enum(name, items)
77 _ENUM_CACHE[name] = (enum, default)
78 return enum, default
77 # unfortunately, we cannot use lru_cache because we still support Python 2
78 _deprecated_object_type_cache = {}
7979
8080
81 def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
82 """Create Graphene Enum for sorting a SQLAlchemy class query
81 def _deprecated_object_type_for_model(cls, name):
8382
84 Parameters
85 - cls : Sqlalchemy model class
86 Model used to create the sort enumerator
87 - name : str, optional, default None
88 Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'`
89 - symbol_name : function, optional, default `_symbol_name`
90 Function which takes the column name and a boolean indicating if the sort direction is ascending,
91 and returns the symbol name for the current column and sort direction.
92 The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc'
83 try:
84 return _deprecated_object_type_cache[cls, name]
85 except KeyError:
86 from .types import SQLAlchemyObjectType
9387
94 Returns
95 - Enum
96 The Graphene enumerator
88 obj_type_name = name or cls.__name__
89
90 class ObjType(SQLAlchemyObjectType):
91 class Meta:
92 name = obj_type_name
93 model = cls
94
95 _deprecated_object_type_cache[cls, name] = ObjType
96 return ObjType
97
98
99 def sort_enum_for_model(cls, name=None, symbol_name=None):
100 """Get a Graphene Enum for sorting the given model class.
101
102 This is deprecated, please use object_type.sort_enum() instead.
97103 """
98 enum, _ = _sort_enum_for_model(cls, name, symbol_name)
99 return enum
104 warnings.warn(
105 "sort_enum_for_model() is deprecated; use object_type.sort_enum() instead.",
106 DeprecationWarning,
107 stacklevel=2,
108 )
109
110 from .enums import sort_enum_for_object_type
111
112 return sort_enum_for_object_type(
113 _deprecated_object_type_for_model(cls, name),
114 name,
115 get_symbol_name=symbol_name or _deprecated_default_symbol_name,
116 )
100117
101118
102119 def sort_argument_for_model(cls, has_default=True):
103 """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model.
104 If `has_default` is True (the default) it will sort the result by the primary key(s)
120 """Get a Graphene Argument for sorting the given model class.
121
122 This is deprecated, please use object_type.sort_argument() instead.
105123 """
106 enum, default = _sort_enum_for_model(cls)
124 warnings.warn(
125 "sort_argument_for_model() is deprecated;"
126 " use object_type.sort_argument() instead.",
127 DeprecationWarning,
128 stacklevel=2,
129 )
130
131 from graphene import Argument, List
132 from .enums import sort_enum_for_object_type
133
134 enum = sort_enum_for_object_type(
135 _deprecated_object_type_for_model(cls, None),
136 get_symbol_name=_deprecated_default_symbol_name,
137 )
107138 if not has_default:
108 default = None
109 return Argument(List(enum), default_value=default)
139 enum.default = None
140
141 return Argument(List(enum), default_value=enum.default)
55 max-line-length = 120
66
77 [isort]
8 no_lines_before=FIRSTPARTY
89 known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme
910 known_first_party=graphene_sqlalchemy
10 known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils
11 known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils
1112 sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER
12 no_lines_before=FIRSTPARTY
13 skip_glob=examples/nameko_sqlalchemy
1314
1415 [bdist_wheel]
1516 universal=1
1313 requirements = [
1414 # To keep things simple, we only support newer versions of Graphene
1515 "graphene>=2.1.3,<3",
16 "promise>=2.3",
1617 # Tests fail with 1.0.19
17 "SQLAlchemy>=1.1,<2",
18 "SQLAlchemy>=1.2,<2",
1819 "six>=1.10.0,<2",
1920 "singledispatch>=3.4.0.3,<4",
2021 ]
2829 "mock==2.0.0",
2930 "pytest-cov==2.6.1",
3031 "sqlalchemy_utils==0.33.9",
32 "pytest-benchmark==3.2.1",
3133 ]
3234
3335 setup(
4648 "Programming Language :: Python :: 2",
4749 "Programming Language :: Python :: 2.7",
4850 "Programming Language :: Python :: 3",
49 "Programming Language :: Python :: 3.3",
50 "Programming Language :: Python :: 3.4",
5151 "Programming Language :: Python :: 3.5",
5252 "Programming Language :: Python :: 3.6",
5353 "Programming Language :: Python :: 3.7",
5959 extras_require={
6060 "dev": [
6161 "tox==3.7.0", # Should be kept in sync with tox.ini
62 "coveralls==1.7.0",
62 "coveralls==1.10.0",
6363 "pre-commit==1.14.4",
6464 ],
6565 "test": tests_require,
00 [tox]
1 envlist = pre-commit,py{27,34,35,36,37}-sql{11,12,13}
1 envlist = pre-commit,py{27,35,36,37}-sql{11,12,13}
22 skipsdist = true
33 minversion = 3.7.0
44