Import upstream version 2.3.0
Kali Janitor
3 years ago
10 | 10 | # Distribution / packaging |
11 | 11 | .Python |
12 | 12 | env/ |
13 | .venv/ | |
13 | 14 | build/ |
14 | 15 | develop-eggs/ |
15 | 16 | dist/ |
24 | 25 | *.egg-info/ |
25 | 26 | .installed.cfg |
26 | 27 | *.egg |
28 | .python-version | |
27 | 29 | |
28 | 30 | # PyInstaller |
29 | 31 | # Usually these files are written by a python script from a template |
45 | 47 | coverage.xml |
46 | 48 | *,cover |
47 | 49 | .pytest_cache/ |
50 | .benchmarks/ | |
48 | 51 | |
49 | 52 | # Translations |
50 | 53 | *.mo |
3 | 3 | # Python 2.7 |
4 | 4 | - env: TOXENV=py27 |
5 | 5 | python: 2.7 |
6 | # Python 3.5 | |
7 | - env: TOXENV=py34 | |
8 | python: 3.4 | |
9 | 6 | # Python 3.5 |
10 | 7 | - env: TOXENV=py35 |
11 | 8 | python: 3.5 |
0 | / @cito @jnak @Nabellaleen |
42 | 42 | class User(SQLAlchemyObjectType): |
43 | 43 | class Meta: |
44 | 44 | model = UserModel |
45 | # only return specified fields | |
46 | only_fields = ("name",) | |
47 | # exclude specified fields | |
48 | exclude_fields = ("last_name",) | |
45 | # use `only_fields` to only expose specific fields ie "name" | |
46 | # only_fields = ("name",) | |
47 | # use `exclude_fields` to exclude specific fields ie "last_name" | |
48 | # exclude_fields = ("last_name",) | |
49 | 49 | |
50 | 50 | class Query(graphene.ObjectType): |
51 | 51 | users = graphene.List(User) |
12 | 12 | interfaces = (relay.Node,) |
13 | 13 | |
14 | 14 | |
15 | class BookConnection(relay.Connection): | |
16 | class Meta: | |
17 | node = Book | |
18 | ||
19 | ||
20 | 15 | class Author(SQLAlchemyObjectType): |
21 | 16 | class Meta: |
22 | 17 | model = AuthorModel |
23 | 18 | interfaces = (relay.Node,) |
24 | ||
25 | ||
26 | class AuthorConnection(relay.Connection): | |
27 | class Meta: | |
28 | node = Author | |
29 | 19 | |
30 | 20 | |
31 | 21 | class SearchResult(graphene.Union): |
38 | 28 | search = graphene.List(SearchResult, q=graphene.String()) # List field for search results |
39 | 29 | |
40 | 30 | # Normal Fields |
41 | all_books = SQLAlchemyConnectionField(BookConnection) | |
42 | all_authors = SQLAlchemyConnectionField(AuthorConnection) | |
31 | all_books = SQLAlchemyConnectionField(Book.connection) | |
32 | all_authors = SQLAlchemyConnectionField(Author.connection) | |
43 | 33 | |
44 | 34 | def resolve_search(self, info, **args): |
45 | 35 | q = args.get("q") # Search query |
49 | 49 | model = Pet |
50 | 50 | |
51 | 51 | |
52 | class PetConnection(Connection): | |
53 | class Meta: | |
54 | node = PetNode | |
55 | ||
56 | ||
57 | 52 | class Query(ObjectType): |
58 | allPets = SQLAlchemyConnectionField(PetConnection) | |
53 | allPets = SQLAlchemyConnectionField(PetNode.connection) | |
59 | 54 | |
60 | 55 | some of the allowed queries are |
61 | 56 |
101 | 101 | interfaces = (relay.Node, ) |
102 | 102 | |
103 | 103 | |
104 | class DepartmentConnection(relay.Connection): | |
105 | class Meta: | |
106 | node = Department | |
107 | ||
108 | ||
109 | 104 | class Employee(SQLAlchemyObjectType): |
110 | 105 | class Meta: |
111 | 106 | model = EmployeeModel |
112 | 107 | interfaces = (relay.Node, ) |
113 | 108 | |
114 | 109 | |
115 | class EmployeeConnection(relay.Connection): | |
116 | class Meta: | |
117 | node = Employee | |
118 | ||
119 | ||
120 | 110 | class Query(graphene.ObjectType): |
121 | 111 | node = relay.Node.Field() |
122 | 112 | # Allows sorting over multiple columns, by default over the primary key |
123 | all_employees = SQLAlchemyConnectionField(EmployeeConnection) | |
113 | all_employees = SQLAlchemyConnectionField(Employee.connection) | |
124 | 114 | # Disable sorting over this field |
125 | all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None) | |
115 | all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) | |
126 | 116 | |
127 | 117 | schema = graphene.Schema(query=Query) |
128 | 118 |
8 | 8 | --------------- |
9 | 9 | |
10 | 10 | First you'll need to get the source of the project. Do this by cloning the |
11 | whole Graphene repository: | |
11 | whole Graphene-SQLAlchemy repository: | |
12 | 12 | |
13 | 13 | ```bash |
14 | 14 | # Get the example project code |
0 | 0 | #!/usr/bin/env python |
1 | 1 | |
2 | from database import db_session, init_db | |
2 | 3 | from flask import Flask |
4 | from schema import schema | |
3 | 5 | |
4 | 6 | from flask_graphql import GraphQLView |
5 | ||
6 | from .database import db_session, init_db | |
7 | from .schema import schema | |
8 | 7 | |
9 | 8 | app = Flask(__name__) |
10 | 9 | app.debug = True |
11 | 10 | |
12 | default_query = ''' | |
11 | example_query = """ | |
13 | 12 | { |
14 | allEmployees { | |
13 | allEmployees(sort: [NAME_ASC, ID_ASC]) { | |
15 | 14 | edges { |
16 | 15 | node { |
17 | id, | |
18 | name, | |
16 | id | |
17 | name | |
19 | 18 | department { |
20 | id, | |
19 | id | |
21 | 20 | name |
22 | }, | |
21 | } | |
23 | 22 | role { |
24 | id, | |
23 | id | |
25 | 24 | name |
26 | 25 | } |
27 | 26 | } |
28 | 27 | } |
29 | 28 | } |
30 | }'''.strip() | |
29 | } | |
30 | """ | |
31 | 31 | |
32 | 32 | |
33 | app.add_url_rule('/graphql', view_func=GraphQLView.as_view('graphql', schema=schema, graphiql=True)) | |
33 | app.add_url_rule( | |
34 | "/graphql", view_func=GraphQLView.as_view("graphql", schema=schema, graphiql=True) | |
35 | ) | |
34 | 36 | |
35 | 37 | |
36 | 38 | @app.teardown_appcontext |
37 | 39 | def shutdown_session(exception=None): |
38 | 40 | db_session.remove() |
39 | 41 | |
40 | if __name__ == '__main__': | |
42 | ||
43 | if __name__ == "__main__": | |
41 | 44 | init_db() |
42 | 45 | app.run() |
13 | 13 | # import all modules here that might define models so that |
14 | 14 | # they will be registered properly on the metadata. Otherwise |
15 | 15 | # you will have to import them first before calling init_db() |
16 | from .models import Department, Employee, Role | |
16 | from models import Department, Employee, Role | |
17 | 17 | Base.metadata.drop_all(bind=engine) |
18 | 18 | Base.metadata.create_all(bind=engine) |
19 | 19 |
0 | from database import Base | |
0 | 1 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func |
1 | 2 | from sqlalchemy.orm import backref, relationship |
2 | ||
3 | from .database import Base | |
4 | 3 | |
5 | 4 | |
6 | 5 | class Department(Base): |
0 | graphene[sqlalchemy] | |
1 | SQLAlchemy==1.0.11 | |
2 | Flask==0.12.4 | |
3 | Flask-GraphQL==1.3.0 | |
0 | -e ../../ | |
1 | Flask-GraphQL |
0 | from models import Department as DepartmentModel | |
1 | from models import Employee as EmployeeModel | |
2 | from models import Role as RoleModel | |
3 | ||
0 | 4 | import graphene |
1 | 5 | from graphene import relay |
2 | from graphene_sqlalchemy import (SQLAlchemyConnectionField, | |
3 | SQLAlchemyObjectType, utils) | |
4 | ||
5 | from .models import Department as DepartmentModel | |
6 | from .models import Employee as EmployeeModel | |
7 | from .models import Role as RoleModel | |
6 | from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType | |
8 | 7 | |
9 | 8 | |
10 | 9 | class Department(SQLAlchemyObjectType): |
25 | 24 | interfaces = (relay.Node, ) |
26 | 25 | |
27 | 26 | |
28 | SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee', | |
29 | lambda c, d: c.upper() + ('_ASC' if d else '_DESC')) | |
30 | ||
31 | ||
32 | 27 | class Query(graphene.ObjectType): |
33 | 28 | node = relay.Node.Field() |
34 | 29 | # Allow only single column sorting |
35 | 30 | all_employees = SQLAlchemyConnectionField( |
36 | Employee, | |
37 | sort=graphene.Argument( | |
38 | SortEnumEmployee, | |
39 | default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc()))) | |
31 | Employee.connection, sort=Employee.sort_argument()) | |
40 | 32 | # Allows sorting over multiple columns, by default over the primary key |
41 | all_roles = SQLAlchemyConnectionField(Role) | |
33 | all_roles = SQLAlchemyConnectionField(Role.connection) | |
42 | 34 | # Disable sorting over this field |
43 | all_departments = SQLAlchemyConnectionField(Department, sort=None) | |
35 | all_departments = SQLAlchemyConnectionField(Department.connection, sort=None) | |
44 | 36 | |
45 | 37 | |
46 | schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) | |
38 | schema = graphene.Schema(query=Query) |
13 | 13 | --------------- |
14 | 14 | |
15 | 15 | First you'll need to get the source of the project. Do this by cloning the |
16 | whole Graphene repository: | |
16 | whole Graphene-SQLAlchemy repository: | |
17 | 17 | |
18 | 18 | ```bash |
19 | 19 | # Get the example project code |
45 | 45 | |
46 | 46 | ```bash |
47 | 47 | ./run.sh |
48 | ||
49 | 48 | ``` |
50 | 49 | |
51 | 50 | Now head on over to postman and send POST request to: |
0 | from database import db_session, init_db | |
1 | from schema import schema | |
2 | ||
0 | 3 | from graphql_server import (HttpQueryError, default_format_error, |
1 | 4 | encode_execution_results, json_encode, |
2 | 5 | load_json_body, run_http_query) |
3 | ||
4 | from .database import db_session, init_db | |
5 | from .schema import schema | |
6 | 6 | |
7 | 7 | |
8 | 8 | class App(): |
13 | 13 | # import all modules here that might define models so that |
14 | 14 | # they will be registered properly on the metadata. Otherwise |
15 | 15 | # you will have to import them first before calling init_db() |
16 | from .models import Department, Employee, Role | |
16 | from models import Department, Employee, Role | |
17 | 17 | Base.metadata.drop_all(bind=engine) |
18 | 18 | Base.metadata.create_all(bind=engine) |
19 | 19 |
0 | from database import Base | |
0 | 1 | from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func |
1 | 2 | from sqlalchemy.orm import backref, relationship |
2 | ||
3 | from .database import Base | |
4 | 3 | |
5 | 4 | |
6 | 5 | class Department(Base): |
0 | graphene[sqlalchemy] | |
1 | SQLAlchemy==1.0.11 | |
0 | -e ../../ | |
1 | graphql-server-core | |
2 | 2 | nameko |
3 | graphql-server-core |
0 | from models import Department as DepartmentModel | |
1 | from models import Employee as EmployeeModel | |
2 | from models import Role as RoleModel | |
3 | ||
0 | 4 | import graphene |
1 | 5 | from graphene import relay |
2 | 6 | from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType |
3 | 7 | |
4 | from .models import Department as DepartmentModel | |
5 | from .models import Employee as EmployeeModel | |
6 | from .models import Role as RoleModel | |
7 | ||
8 | 8 | |
9 | 9 | class Department(SQLAlchemyObjectType): |
10 | ||
11 | 10 | class Meta: |
12 | 11 | model = DepartmentModel |
13 | interfaces = (relay.Node, ) | |
12 | interfaces = (relay.Node,) | |
14 | 13 | |
15 | 14 | |
16 | 15 | class Employee(SQLAlchemyObjectType): |
17 | ||
18 | 16 | class Meta: |
19 | 17 | model = EmployeeModel |
20 | interfaces = (relay.Node, ) | |
18 | interfaces = (relay.Node,) | |
21 | 19 | |
22 | 20 | |
23 | 21 | class Role(SQLAlchemyObjectType): |
24 | ||
25 | 22 | class Meta: |
26 | 23 | model = RoleModel |
27 | interfaces = (relay.Node, ) | |
24 | interfaces = (relay.Node,) | |
28 | 25 | |
29 | 26 | |
30 | 27 | class Query(graphene.ObjectType): |
31 | 28 | node = relay.Node.Field() |
32 | all_employees = SQLAlchemyConnectionField(Employee) | |
33 | all_roles = SQLAlchemyConnectionField(Role) | |
29 | all_employees = SQLAlchemyConnectionField(Employee.connection) | |
30 | all_roles = SQLAlchemyConnectionField(Role.connection) | |
34 | 31 | role = graphene.Field(Role) |
35 | 32 | |
36 | 33 | |
37 | schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) | |
34 | schema = graphene.Schema(query=Query) |
0 | 0 | #!/usr/bin/env python |
1 | from app import App | |
1 | 2 | from nameko.web.handlers import http |
2 | ||
3 | from .app import App | |
4 | 3 | |
5 | 4 | |
6 | 5 | class DepartmentService: |
1 | 1 | from .fields import SQLAlchemyConnectionField |
2 | 2 | from .utils import get_query, get_session |
3 | 3 | |
4 | __version__ = "2.1.2" | |
4 | __version__ = "2.3.0" | |
5 | 5 | |
6 | 6 | __all__ = [ |
7 | 7 | "__version__", |
0 | import sqlalchemy | |
1 | from promise import dataloader, promise | |
2 | from sqlalchemy.orm import Session, strategies | |
3 | from sqlalchemy.orm.query import QueryContext | |
4 | ||
5 | ||
6 | def get_batch_resolver(relationship_prop): | |
7 | ||
8 | # Cache this across `batch_load_fn` calls | |
9 | # This is so SQL string generation is cached under-the-hood via `bakery` | |
10 | selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) | |
11 | ||
12 | class RelationshipLoader(dataloader.DataLoader): | |
13 | cache = False | |
14 | ||
15 | def batch_load_fn(self, parents): # pylint: disable=method-hidden | |
16 | """ | |
17 | Batch loads the relationships of all the parents as one SQL statement. | |
18 | ||
19 | There is no way to do this out-of-the-box with SQLAlchemy but | |
20 | we can piggyback on some internal APIs of the `selectin` | |
21 | eager loading strategy. It's a bit hacky but it's preferable | |
22 | than re-implementing and maintainnig a big chunk of the `selectin` | |
23 | loader logic ourselves. | |
24 | ||
25 | The approach here is to build a regular query that | |
26 | selects the parent and `selectin` load the relationship. | |
27 | But instead of having the query emits 2 `SELECT` statements | |
28 | when callling `all()`, we skip the first `SELECT` statement | |
29 | and jump right before the `selectin` loader is called. | |
30 | To accomplish this, we have to construct objects that are | |
31 | normally built in the first part of the query in order | |
32 | to call directly `SelectInLoader._load_for_path`. | |
33 | ||
34 | TODO Move this logic to a util in the SQLAlchemy repo as per | |
35 | SQLAlchemy's main maitainer suggestion. | |
36 | See https://git.io/JewQ7 | |
37 | """ | |
38 | child_mapper = relationship_prop.mapper | |
39 | parent_mapper = relationship_prop.parent | |
40 | session = Session.object_session(parents[0]) | |
41 | ||
42 | # These issues are very unlikely to happen in practice... | |
43 | for parent in parents: | |
44 | # assert parent.__mapper__ is parent_mapper | |
45 | # All instances must share the same session | |
46 | assert session is Session.object_session(parent) | |
47 | # The behavior of `selectin` is undefined if the parent is dirty | |
48 | assert parent not in session.dirty | |
49 | ||
50 | # Should the boolean be set to False? Does it matter for our purposes? | |
51 | states = [(sqlalchemy.inspect(parent), True) for parent in parents] | |
52 | ||
53 | # For our purposes, the query_context will only used to get the session | |
54 | query_context = QueryContext(session.query(parent_mapper.entity)) | |
55 | ||
56 | selectin_loader._load_for_path( | |
57 | query_context, | |
58 | parent_mapper._path_registry, | |
59 | states, | |
60 | None, | |
61 | child_mapper, | |
62 | ) | |
63 | ||
64 | return promise.Promise.resolve([getattr(parent, relationship_prop.key) for parent in parents]) | |
65 | ||
66 | loader = RelationshipLoader() | |
67 | ||
68 | def resolve(root, info, **args): | |
69 | return loader.load(root) | |
70 | ||
71 | return resolve |
0 | from enum import EnumMeta | |
1 | ||
0 | 2 | from singledispatch import singledispatch |
1 | 3 | from sqlalchemy import types |
2 | 4 | from sqlalchemy.dialects import postgresql |
3 | from sqlalchemy.orm import interfaces | |
5 | from sqlalchemy.orm import interfaces, strategies | |
4 | 6 | |
5 | 7 | from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List, |
6 | 8 | String) |
7 | 9 | from graphene.types.json import JSONString |
10 | ||
11 | from .batching import get_batch_resolver | |
12 | from .enums import enum_for_sa_enum | |
13 | from .fields import (BatchSQLAlchemyConnectionField, | |
14 | default_connection_field_factory) | |
15 | from .registry import get_global_registry | |
16 | from .resolvers import get_attr_resolver, get_custom_resolver | |
8 | 17 | |
9 | 18 | try: |
10 | 19 | from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType |
12 | 21 | ChoiceType = JSONType = ScalarListType = TSVectorType = object |
13 | 22 | |
14 | 23 | |
24 | is_selectin_available = getattr(strategies, 'SelectInLoader', None) | |
25 | ||
26 | ||
15 | 27 | def get_column_doc(column): |
16 | 28 | return getattr(column, "doc", None) |
17 | 29 | |
20 | 32 | return bool(getattr(column, "nullable", True)) |
21 | 33 | |
22 | 34 | |
23 | def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): | |
24 | direction = relationship.direction | |
25 | model = relationship.mapper.entity | |
26 | ||
35 | def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching, | |
36 | orm_field_name, **field_kwargs): | |
37 | """ | |
38 | :param sqlalchemy.RelationshipProperty relationship_prop: | |
39 | :param SQLAlchemyObjectType obj_type: | |
40 | :param function|None connection_field_factory: | |
41 | :param bool batching: | |
42 | :param str orm_field_name: | |
43 | :param dict field_kwargs: | |
44 | :rtype: Dynamic | |
45 | """ | |
27 | 46 | def dynamic_type(): |
28 | _type = registry.get_type_for_model(model) | |
29 | if not _type: | |
47 | """:rtype: Field|None""" | |
48 | direction = relationship_prop.direction | |
49 | child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) | |
50 | batching_ = batching if is_selectin_available else False | |
51 | ||
52 | if not child_type: | |
30 | 53 | return None |
31 | if direction == interfaces.MANYTOONE or not relationship.uselist: | |
32 | return Field(_type) | |
33 | elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): | |
34 | if _type._meta.connection: | |
35 | return connection_field_factory(relationship, registry) | |
36 | return Field(List(_type)) | |
54 | ||
55 | if direction == interfaces.MANYTOONE or not relationship_prop.uselist: | |
56 | return _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching_, orm_field_name, | |
57 | **field_kwargs) | |
58 | ||
59 | if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): | |
60 | return _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching_, | |
61 | connection_field_factory, **field_kwargs) | |
37 | 62 | |
38 | 63 | return Dynamic(dynamic_type) |
39 | 64 | |
40 | 65 | |
41 | def convert_sqlalchemy_hybrid_method(hybrid_item): | |
42 | return String(description=getattr(hybrid_item, "__doc__", None), required=False) | |
43 | ||
44 | ||
45 | def convert_sqlalchemy_composite(composite, registry): | |
46 | converter = registry.get_converter_for_composite(composite.composite_class) | |
66 | def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs): | |
67 | """ | |
68 | Convert one-to-one or many-to-one relationshsip. Return an object field. | |
69 | ||
70 | :param sqlalchemy.RelationshipProperty relationship_prop: | |
71 | :param SQLAlchemyObjectType obj_type: | |
72 | :param bool batching: | |
73 | :param str orm_field_name: | |
74 | :param dict field_kwargs: | |
75 | :rtype: Field | |
76 | """ | |
77 | child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) | |
78 | ||
79 | resolver = get_custom_resolver(obj_type, orm_field_name) | |
80 | if resolver is None: | |
81 | resolver = get_batch_resolver(relationship_prop) if batching else \ | |
82 | get_attr_resolver(obj_type, relationship_prop.key) | |
83 | ||
84 | return Field(child_type, resolver=resolver, **field_kwargs) | |
85 | ||
86 | ||
87 | def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs): | |
88 | """ | |
89 | Convert one-to-many or many-to-many relationshsip. Return a list field or a connection field. | |
90 | ||
91 | :param sqlalchemy.RelationshipProperty relationship_prop: | |
92 | :param SQLAlchemyObjectType obj_type: | |
93 | :param bool batching: | |
94 | :param function|None connection_field_factory: | |
95 | :param dict field_kwargs: | |
96 | :rtype: Field | |
97 | """ | |
98 | child_type = obj_type._meta.registry.get_type_for_model(relationship_prop.mapper.entity) | |
99 | ||
100 | if not child_type._meta.connection: | |
101 | return Field(List(child_type), **field_kwargs) | |
102 | ||
103 | # TODO Allow override of connection_field_factory and resolver via ORMField | |
104 | if connection_field_factory is None: | |
105 | connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship if batching else \ | |
106 | default_connection_field_factory | |
107 | ||
108 | return connection_field_factory(relationship_prop, obj_type._meta.registry, **field_kwargs) | |
109 | ||
110 | ||
111 | def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs): | |
112 | if 'type' not in field_kwargs: | |
113 | # TODO The default type should be dependent on the type of the property propety. | |
114 | field_kwargs['type'] = String | |
115 | ||
116 | return Field( | |
117 | resolver=resolver, | |
118 | **field_kwargs | |
119 | ) | |
120 | ||
121 | ||
122 | def convert_sqlalchemy_composite(composite_prop, registry, resolver): | |
123 | converter = registry.get_converter_for_composite(composite_prop.composite_class) | |
47 | 124 | if not converter: |
48 | 125 | try: |
49 | 126 | raise Exception( |
50 | 127 | "Don't know how to convert the composite field %s (%s)" |
51 | % (composite, composite.composite_class) | |
128 | % (composite_prop, composite_prop.composite_class) | |
52 | 129 | ) |
53 | 130 | except AttributeError: |
54 | 131 | # handle fields that are not attached to a class yet (don't have a parent) |
55 | 132 | raise Exception( |
56 | 133 | "Don't know how to convert the composite field %r (%s)" |
57 | % (composite, composite.composite_class) | |
134 | % (composite_prop, composite_prop.composite_class) | |
58 | 135 | ) |
59 | return converter(composite, registry) | |
136 | ||
137 | # TODO Add a way to override composite fields default parameters | |
138 | return converter(composite_prop, registry) | |
60 | 139 | |
61 | 140 | |
62 | 141 | def _register_composite_class(cls, registry=None): |
74 | 153 | convert_sqlalchemy_composite.register = _register_composite_class |
75 | 154 | |
76 | 155 | |
77 | def convert_sqlalchemy_column(column, registry=None): | |
78 | return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) | |
156 | def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): | |
157 | column = column_prop.columns[0] | |
158 | field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry)) | |
159 | field_kwargs.setdefault('required', not is_column_nullable(column)) | |
160 | field_kwargs.setdefault('description', get_column_doc(column)) | |
161 | ||
162 | return Field( | |
163 | resolver=resolver, | |
164 | **field_kwargs | |
165 | ) | |
79 | 166 | |
80 | 167 | |
81 | 168 | @singledispatch |
97 | 184 | @convert_sqlalchemy_type.register(postgresql.CIDR) |
98 | 185 | @convert_sqlalchemy_type.register(TSVectorType) |
99 | 186 | def convert_column_to_string(type, column, registry=None): |
100 | return String( | |
101 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
102 | ) | |
187 | return String | |
103 | 188 | |
104 | 189 | |
105 | 190 | @convert_sqlalchemy_type.register(types.DateTime) |
106 | 191 | def convert_column_to_datetime(type, column, registry=None): |
107 | 192 | from graphene.types.datetime import DateTime |
108 | ||
109 | return DateTime( | |
110 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
111 | ) | |
193 | return DateTime | |
112 | 194 | |
113 | 195 | |
114 | 196 | @convert_sqlalchemy_type.register(types.SmallInteger) |
115 | 197 | @convert_sqlalchemy_type.register(types.Integer) |
116 | 198 | def convert_column_to_int_or_id(type, column, registry=None): |
117 | if column.primary_key: | |
118 | return ID( | |
119 | description=get_column_doc(column), | |
120 | required=not (is_column_nullable(column)), | |
121 | ) | |
122 | else: | |
123 | return Int( | |
124 | description=get_column_doc(column), | |
125 | required=not (is_column_nullable(column)), | |
126 | ) | |
199 | return ID if column.primary_key else Int | |
127 | 200 | |
128 | 201 | |
129 | 202 | @convert_sqlalchemy_type.register(types.Boolean) |
130 | 203 | def convert_column_to_boolean(type, column, registry=None): |
131 | return Boolean( | |
132 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
133 | ) | |
204 | return Boolean | |
134 | 205 | |
135 | 206 | |
136 | 207 | @convert_sqlalchemy_type.register(types.Float) |
137 | 208 | @convert_sqlalchemy_type.register(types.Numeric) |
138 | 209 | @convert_sqlalchemy_type.register(types.BigInteger) |
139 | 210 | def convert_column_to_float(type, column, registry=None): |
140 | return Float( | |
141 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
142 | ) | |
211 | return Float | |
143 | 212 | |
144 | 213 | |
145 | 214 | @convert_sqlalchemy_type.register(types.Enum) |
146 | 215 | def convert_enum_to_enum(type, column, registry=None): |
147 | enum_class = getattr(type, 'enum_class', None) | |
148 | if enum_class: # Check if an enum.Enum type is used | |
149 | graphene_type = Enum.from_enum(enum_class) | |
150 | else: # Nope, just a list of string options | |
151 | items = zip(type.enums, type.enums) | |
152 | graphene_type = Enum(type.name, items) | |
153 | return Field( | |
154 | graphene_type, | |
155 | description=get_column_doc(column), | |
156 | required=not (is_column_nullable(column)), | |
157 | ) | |
158 | ||
159 | ||
216 | return lambda: enum_for_sa_enum(type, registry or get_global_registry()) | |
217 | ||
218 | ||
219 | # TODO Make ChoiceType conversion consistent with other enums | |
160 | 220 | @convert_sqlalchemy_type.register(ChoiceType) |
161 | def convert_column_to_enum(type, column, registry=None): | |
221 | def convert_choice_to_enum(type, column, registry=None): | |
162 | 222 | name = "{}_{}".format(column.table.name, column.name).upper() |
163 | return Enum(name, type.choices, description=get_column_doc(column)) | |
223 | if isinstance(type.choices, EnumMeta): | |
224 | # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta | |
225 | # do not use from_enum here because we can have more than one enum column in table | |
226 | return Enum(name, list((v.name, v.value) for v in type.choices)) | |
227 | else: | |
228 | return Enum(name, type.choices) | |
164 | 229 | |
165 | 230 | |
166 | 231 | @convert_sqlalchemy_type.register(ScalarListType) |
167 | 232 | def convert_scalar_list_to_list(type, column, registry=None): |
168 | return List(String, description=get_column_doc(column)) | |
169 | ||
170 | ||
233 | return List(String) | |
234 | ||
235 | ||
236 | @convert_sqlalchemy_type.register(types.ARRAY) | |
171 | 237 | @convert_sqlalchemy_type.register(postgresql.ARRAY) |
172 | def convert_postgres_array_to_list(_type, column, registry=None): | |
173 | graphene_type = convert_sqlalchemy_type(column.type.item_type, column) | |
174 | inner_type = type(graphene_type) | |
175 | return List( | |
176 | inner_type, | |
177 | description=get_column_doc(column), | |
178 | required=not (is_column_nullable(column)), | |
179 | ) | |
238 | def convert_array_to_list(_type, column, registry=None): | |
239 | inner_type = convert_sqlalchemy_type(column.type.item_type, column) | |
240 | return List(inner_type) | |
180 | 241 | |
181 | 242 | |
182 | 243 | @convert_sqlalchemy_type.register(postgresql.HSTORE) |
183 | 244 | @convert_sqlalchemy_type.register(postgresql.JSON) |
184 | 245 | @convert_sqlalchemy_type.register(postgresql.JSONB) |
185 | 246 | def convert_json_to_string(type, column, registry=None): |
186 | return JSONString( | |
187 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
188 | ) | |
247 | return JSONString | |
189 | 248 | |
190 | 249 | |
191 | 250 | @convert_sqlalchemy_type.register(JSONType) |
192 | 251 | def convert_json_type_to_string(type, column, registry=None): |
193 | return JSONString( | |
194 | description=get_column_doc(column), required=not (is_column_nullable(column)) | |
195 | ) | |
252 | return JSONString |
0 | import six | |
1 | from sqlalchemy.orm import ColumnProperty | |
2 | from sqlalchemy.types import Enum as SQLAlchemyEnumType | |
3 | ||
4 | from graphene import Argument, Enum, List | |
5 | ||
6 | from .utils import EnumValue, to_enum_value_name, to_type_name | |
7 | ||
8 | ||
9 | def _convert_sa_to_graphene_enum(sa_enum, fallback_name=None): | |
10 | """Convert the given SQLAlchemy Enum type to a Graphene Enum type. | |
11 | ||
12 | The name of the Graphene Enum will be determined as follows: | |
13 | If the SQLAlchemy Enum is based on a Python Enum, use the name | |
14 | of the Python Enum. Otherwise, if the SQLAlchemy Enum is named, | |
15 | use the SQL name after conversion to a type name. Otherwise, use | |
16 | the given fallback_name or raise an error if it is empty. | |
17 | ||
18 | The Enum value names are converted to upper case if necessary. | |
19 | """ | |
20 | if not isinstance(sa_enum, SQLAlchemyEnumType): | |
21 | raise TypeError( | |
22 | "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) | |
23 | ) | |
24 | enum_class = sa_enum.enum_class | |
25 | if enum_class: | |
26 | if all(to_enum_value_name(key) == key for key in enum_class.__members__): | |
27 | return Enum.from_enum(enum_class) | |
28 | name = enum_class.__name__ | |
29 | members = [ | |
30 | (to_enum_value_name(key), value.value) | |
31 | for key, value in enum_class.__members__.items() | |
32 | ] | |
33 | else: | |
34 | sql_enum_name = sa_enum.name | |
35 | if sql_enum_name: | |
36 | name = to_type_name(sql_enum_name) | |
37 | elif fallback_name: | |
38 | name = fallback_name | |
39 | else: | |
40 | raise TypeError("No type name specified for {!r}".format(sa_enum)) | |
41 | members = [(to_enum_value_name(key), key) for key in sa_enum.enums] | |
42 | return Enum(name, members) | |
43 | ||
44 | ||
45 | def enum_for_sa_enum(sa_enum, registry): | |
46 | """Return the Graphene Enum type for the specified SQLAlchemy Enum type.""" | |
47 | if not isinstance(sa_enum, SQLAlchemyEnumType): | |
48 | raise TypeError( | |
49 | "Expected sqlalchemy.types.Enum, but got: {!r}".format(sa_enum) | |
50 | ) | |
51 | enum = registry.get_graphene_enum_for_sa_enum(sa_enum) | |
52 | if not enum: | |
53 | enum = _convert_sa_to_graphene_enum(sa_enum) | |
54 | registry.register_enum(sa_enum, enum) | |
55 | return enum | |
56 | ||
57 | ||
58 | def enum_for_field(obj_type, field_name): | |
59 | """Return the Graphene Enum type for the specified Graphene field.""" | |
60 | from .types import SQLAlchemyObjectType | |
61 | ||
62 | if not isinstance(obj_type, type) or not issubclass(obj_type, SQLAlchemyObjectType): | |
63 | raise TypeError( | |
64 | "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)) | |
65 | if not field_name or not isinstance(field_name, six.string_types): | |
66 | raise TypeError( | |
67 | "Expected a field name, but got: {!r}".format(field_name)) | |
68 | registry = obj_type._meta.registry | |
69 | orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) | |
70 | if orm_field is None: | |
71 | raise TypeError("Cannot get {}.{}".format(obj_type._meta.name, field_name)) | |
72 | if not isinstance(orm_field, ColumnProperty): | |
73 | raise TypeError( | |
74 | "{}.{} does not map to model column".format(obj_type._meta.name, field_name) | |
75 | ) | |
76 | column = orm_field.columns[0] | |
77 | sa_enum = column.type | |
78 | if not isinstance(sa_enum, SQLAlchemyEnumType): | |
79 | raise TypeError( | |
80 | "{}.{} does not map to enum column".format(obj_type._meta.name, field_name) | |
81 | ) | |
82 | enum = registry.get_graphene_enum_for_sa_enum(sa_enum) | |
83 | if not enum: | |
84 | fallback_name = obj_type._meta.name + to_type_name(field_name) | |
85 | enum = _convert_sa_to_graphene_enum(sa_enum, fallback_name) | |
86 | registry.register_enum(sa_enum, enum) | |
87 | return enum | |
88 | ||
89 | ||
90 | def _default_sort_enum_symbol_name(column_name, sort_asc=True): | |
91 | return to_enum_value_name(column_name) + ("_ASC" if sort_asc else "_DESC") | |
92 | ||
93 | ||
94 | def sort_enum_for_object_type( | |
95 | obj_type, name=None, only_fields=None, only_indexed=None, get_symbol_name=None | |
96 | ): | |
97 | """Return Graphene Enum for sorting the given SQLAlchemyObjectType. | |
98 | ||
99 | Parameters | |
100 | - obj_type : SQLAlchemyObjectType | |
101 | The object type for which the sort Enum shall be generated. | |
102 | - name : str, optional, default None | |
103 | Name to use for the sort Enum. | |
104 | If not provided, it will be set to the object type name + 'SortEnum' | |
105 | - only_fields : sequence, optional, default None | |
106 | If this is set, only fields from this sequence will be considered. | |
107 | - only_indexed : bool, optional, default False | |
108 | If this is set, only indexed columns will be considered. | |
109 | - get_symbol_name : function, optional, default None | |
110 | Function which takes the column name and a boolean indicating | |
111 | if the sort direction is ascending, and returns the symbol name | |
112 | for the current column and sort direction. If no such function | |
113 | is passed, a default function will be used that creates the symbols | |
114 | 'foo_asc' and 'foo_desc' for a column with the name 'foo'. | |
115 | ||
116 | Returns | |
117 | - Enum | |
118 | The Graphene Enum type | |
119 | """ | |
120 | name = name or obj_type._meta.name + "SortEnum" | |
121 | registry = obj_type._meta.registry | |
122 | enum = registry.get_sort_enum_for_object_type(obj_type) | |
123 | custom_options = dict( | |
124 | only_fields=only_fields, | |
125 | only_indexed=only_indexed, | |
126 | get_symbol_name=get_symbol_name, | |
127 | ) | |
128 | if enum: | |
129 | if name != enum.__name__ or custom_options != enum.custom_options: | |
130 | raise ValueError( | |
131 | "Sort enum for {} has already been customized".format(obj_type) | |
132 | ) | |
133 | else: | |
134 | members = [] | |
135 | default = [] | |
136 | fields = obj_type._meta.fields | |
137 | get_name = get_symbol_name or _default_sort_enum_symbol_name | |
138 | for field_name in fields: | |
139 | if only_fields and field_name not in only_fields: | |
140 | continue | |
141 | orm_field = registry.get_orm_field_for_graphene_field(obj_type, field_name) | |
142 | if not isinstance(orm_field, ColumnProperty): | |
143 | continue | |
144 | column = orm_field.columns[0] | |
145 | if only_indexed and not (column.primary_key or column.index): | |
146 | continue | |
147 | asc_name = get_name(column.name, True) | |
148 | asc_value = EnumValue(asc_name, column.asc()) | |
149 | desc_name = get_name(column.name, False) | |
150 | desc_value = EnumValue(desc_name, column.desc()) | |
151 | if column.primary_key: | |
152 | default.append(asc_value) | |
153 | members.extend(((asc_name, asc_value), (desc_name, desc_value))) | |
154 | enum = Enum(name, members) | |
155 | enum.default = default # store default as attribute | |
156 | enum.custom_options = custom_options | |
157 | registry.register_sort_enum(obj_type, enum) | |
158 | return enum | |
159 | ||
160 | ||
161 | def sort_argument_for_object_type( | |
162 | obj_type, | |
163 | enum_name=None, | |
164 | only_fields=None, | |
165 | only_indexed=None, | |
166 | get_symbol_name=None, | |
167 | has_default=True, | |
168 | ): | |
169 | """"Returns Graphene Argument for sorting the given SQLAlchemyObjectType. | |
170 | ||
171 | Parameters | |
172 | - obj_type : SQLAlchemyObjectType | |
173 | The object type for which the sort Argument shall be generated. | |
174 | - enum_name : str, optional, default None | |
175 | Name to use for the sort Enum. | |
176 | If not provided, it will be set to the object type name + 'SortEnum' | |
177 | - only_fields : sequence, optional, default None | |
178 | If this is set, only fields from this sequence will be considered. | |
179 | - only_indexed : bool, optional, default False | |
180 | If this is set, only indexed columns will be considered. | |
181 | - get_symbol_name : function, optional, default None | |
182 | Function which takes the column name and a boolean indicating | |
183 | if the sort direction is ascending, and returns the symbol name | |
184 | for the current column and sort direction. If no such function | |
185 | is passed, a default function will be used that creates the symbols | |
186 | 'foo_asc' and 'foo_desc' for a column with the name 'foo'. | |
187 | - has_default : bool, optional, default True | |
188 | If this is set to False, no sorting will happen when this argument is not | |
189 | passed. Otherwise results will be sortied by the primary key(s) of the model. | |
190 | ||
191 | Returns | |
192 | - Enum | |
193 | A Graphene Argument that accepts a list of sorting directions for the model. | |
194 | """ | |
195 | enum = sort_enum_for_object_type( | |
196 | obj_type, | |
197 | enum_name, | |
198 | only_fields=only_fields, | |
199 | only_indexed=only_indexed, | |
200 | get_symbol_name=get_symbol_name, | |
201 | ) | |
202 | if not has_default: | |
203 | enum.default = None | |
204 | ||
205 | return Argument(List(enum), default_value=enum.default) |
0 | import logging | |
0 | import warnings | |
1 | 1 | from functools import partial |
2 | 2 | |
3 | import six | |
3 | 4 | from promise import Promise, is_thenable |
4 | 5 | from sqlalchemy.orm.query import Query |
5 | 6 | |
7 | from graphene import NonNull | |
6 | 8 | from graphene.relay import Connection, ConnectionField |
7 | 9 | from graphene.relay.connection import PageInfo |
8 | 10 | from graphql_relay.connection.arrayconnection import connection_from_list_slice |
9 | 11 | |
10 | from .utils import get_query, sort_argument_for_model | |
11 | ||
12 | log = logging.getLogger() | |
12 | from .batching import get_batch_resolver | |
13 | from .utils import get_query | |
13 | 14 | |
14 | 15 | |
15 | 16 | class UnsortedSQLAlchemyConnectionField(ConnectionField): |
18 | 19 | from .types import SQLAlchemyObjectType |
19 | 20 | |
20 | 21 | _type = super(ConnectionField, self).type |
21 | if issubclass(_type, Connection): | |
22 | nullable_type = get_nullable_type(_type) | |
23 | if issubclass(nullable_type, Connection): | |
22 | 24 | return _type |
23 | assert issubclass(_type, SQLAlchemyObjectType), ( | |
25 | assert issubclass(nullable_type, SQLAlchemyObjectType), ( | |
24 | 26 | "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" |
25 | ).format(_type.__name__) | |
26 | assert _type._meta.connection, "The type {} doesn't have a connection".format( | |
27 | _type.__name__ | |
27 | ).format(nullable_type.__name__) | |
28 | assert ( | |
29 | nullable_type.connection | |
30 | ), "The type {} doesn't have a connection".format( | |
31 | nullable_type.__name__ | |
28 | 32 | ) |
29 | return _type._meta.connection | |
33 | assert _type == nullable_type, ( | |
34 | "Passing a SQLAlchemyObjectType instance is deprecated. " | |
35 | "Pass the connection type instead accessible via SQLAlchemyObjectType.connection" | |
36 | ) | |
37 | return nullable_type.connection | |
30 | 38 | |
31 | 39 | @property |
32 | 40 | def model(self): |
33 | return self.type._meta.node._meta.model | |
41 | return get_nullable_type(self.type)._meta.node._meta.model | |
34 | 42 | |
35 | 43 | @classmethod |
36 | def get_query(cls, model, info, sort=None, **args): | |
37 | query = get_query(model, info.context) | |
38 | if sort is not None: | |
39 | if isinstance(sort, str): | |
40 | query = query.order_by(sort.value) | |
41 | else: | |
42 | query = query.order_by(*(col.value for col in sort)) | |
43 | return query | |
44 | def get_query(cls, model, info, **args): | |
45 | return get_query(model, info.context) | |
44 | 46 | |
45 | 47 | @classmethod |
46 | 48 | def resolve_connection(cls, connection_type, model, info, args, resolved): |
75 | 77 | return on_resolve(resolved) |
76 | 78 | |
77 | 79 | def get_resolver(self, parent_resolver): |
78 | return partial(self.connection_resolver, parent_resolver, self.type, self.model) | |
80 | return partial( | |
81 | self.connection_resolver, | |
82 | parent_resolver, | |
83 | get_nullable_type(self.type), | |
84 | self.model, | |
85 | ) | |
79 | 86 | |
80 | 87 | |
88 | # TODO Rename this to SortableSQLAlchemyConnectionField | |
81 | 89 | class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): |
82 | 90 | def __init__(self, type, *args, **kwargs): |
83 | if "sort" not in kwargs and issubclass(type, Connection): | |
91 | nullable_type = get_nullable_type(type) | |
92 | if "sort" not in kwargs and issubclass(nullable_type, Connection): | |
84 | 93 | # Let super class raise if type is not a Connection |
85 | 94 | try: |
86 | model = type.Edge.node._type._meta.model | |
87 | kwargs.setdefault("sort", sort_argument_for_model(model)) | |
88 | except Exception: | |
89 | raise Exception( | |
95 | kwargs.setdefault("sort", nullable_type.Edge.node._type.sort_argument()) | |
96 | except (AttributeError, TypeError): | |
97 | raise TypeError( | |
90 | 98 | 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' |
91 | 99 | " to None to disabling the creation of the sort query argument".format( |
92 | type.__name__ | |
100 | nullable_type.__name__ | |
93 | 101 | ) |
94 | 102 | ) |
95 | 103 | elif "sort" in kwargs and kwargs["sort"] is None: |
96 | 104 | del kwargs["sort"] |
97 | 105 | super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) |
98 | 106 | |
107 | @classmethod | |
108 | def get_query(cls, model, info, sort=None, **args): | |
109 | query = get_query(model, info.context) | |
110 | if sort is not None: | |
111 | if isinstance(sort, six.string_types): | |
112 | query = query.order_by(sort.value) | |
113 | else: | |
114 | query = query.order_by(*(col.value for col in sort)) | |
115 | return query | |
99 | 116 | |
100 | def default_connection_field_factory(relationship, registry): | |
117 | ||
118 | class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): | |
119 | """ | |
120 | This is currently experimental. | |
121 | The API and behavior may change in future versions. | |
122 | Use at your own risk. | |
123 | """ | |
124 | ||
125 | def get_resolver(self, parent_resolver): | |
126 | return partial( | |
127 | self.connection_resolver, | |
128 | self.resolver, | |
129 | get_nullable_type(self.type), | |
130 | self.model, | |
131 | ) | |
132 | ||
133 | @classmethod | |
134 | def from_relationship(cls, relationship, registry, **field_kwargs): | |
135 | model = relationship.mapper.entity | |
136 | model_type = registry.get_type_for_model(model) | |
137 | return cls(model_type.connection, resolver=get_batch_resolver(relationship), **field_kwargs) | |
138 | ||
139 | ||
140 | def default_connection_field_factory(relationship, registry, **field_kwargs): | |
101 | 141 | model = relationship.mapper.entity |
102 | 142 | model_type = registry.get_type_for_model(model) |
103 | return createConnectionField(model_type) | |
143 | return __connectionFactory(model_type, **field_kwargs) | |
104 | 144 | |
105 | 145 | |
106 | 146 | # TODO Remove in next major version |
107 | 147 | __connectionFactory = UnsortedSQLAlchemyConnectionField |
108 | 148 | |
109 | 149 | |
110 | def createConnectionField(_type): | |
111 | log.warn( | |
150 | def createConnectionField(_type, **field_kwargs): | |
151 | warnings.warn( | |
112 | 152 | 'createConnectionField is deprecated and will be removed in the next ' |
113 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' | |
153 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', | |
154 | DeprecationWarning, | |
114 | 155 | ) |
115 | return __connectionFactory(_type) | |
156 | return __connectionFactory(_type, **field_kwargs) | |
116 | 157 | |
117 | 158 | |
118 | 159 | def registerConnectionFieldFactory(factoryMethod): |
119 | log.warn( | |
160 | warnings.warn( | |
120 | 161 | 'registerConnectionFieldFactory is deprecated and will be removed in the next ' |
121 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' | |
162 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', | |
163 | DeprecationWarning, | |
122 | 164 | ) |
123 | 165 | global __connectionFactory |
124 | 166 | __connectionFactory = factoryMethod |
125 | 167 | |
126 | 168 | |
127 | 169 | def unregisterConnectionFieldFactory(): |
128 | log.warn( | |
170 | warnings.warn( | |
129 | 171 | 'registerConnectionFieldFactory is deprecated and will be removed in the next ' |
130 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' | |
172 | 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.', | |
173 | DeprecationWarning, | |
131 | 174 | ) |
132 | 175 | global __connectionFactory |
133 | 176 | __connectionFactory = UnsortedSQLAlchemyConnectionField |
177 | ||
178 | ||
179 | def get_nullable_type(_type): | |
180 | if isinstance(_type, NonNull): | |
181 | return _type.of_type | |
182 | return _type |
0 | from collections import defaultdict | |
1 | ||
2 | import six | |
3 | from sqlalchemy.types import Enum as SQLAlchemyEnumType | |
4 | ||
5 | from graphene import Enum | |
6 | ||
7 | ||
0 | 8 | class Registry(object): |
1 | 9 | def __init__(self): |
2 | 10 | self._registry = {} |
3 | 11 | self._registry_models = {} |
12 | self._registry_orm_fields = defaultdict(dict) | |
4 | 13 | self._registry_composites = {} |
14 | self._registry_enums = {} | |
15 | self._registry_sort_enums = {} | |
5 | 16 | |
6 | def register(self, cls): | |
17 | def register(self, obj_type): | |
7 | 18 | from .types import SQLAlchemyObjectType |
8 | 19 | |
9 | assert issubclass(cls, SQLAlchemyObjectType), ( | |
10 | "Only classes of type SQLAlchemyObjectType can be registered, " | |
11 | 'received "{}"' | |
12 | ).format(cls.__name__) | |
13 | assert cls._meta.registry == self, "Registry for a Model have to match." | |
20 | if not isinstance(obj_type, type) or not issubclass( | |
21 | obj_type, SQLAlchemyObjectType | |
22 | ): | |
23 | raise TypeError( | |
24 | "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) | |
25 | ) | |
26 | assert obj_type._meta.registry == self, "Registry for a Model have to match." | |
14 | 27 | # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( |
15 | 28 | # 'SQLAlchemy model "{}" already associated with ' |
16 | 29 | # 'another type "{}".' |
17 | 30 | # ).format(cls._meta.model, self._registry[cls._meta.model]) |
18 | self._registry[cls._meta.model] = cls | |
31 | self._registry[obj_type._meta.model] = obj_type | |
19 | 32 | |
20 | 33 | def get_type_for_model(self, model): |
21 | 34 | return self._registry.get(model) |
35 | ||
36 | def register_orm_field(self, obj_type, field_name, orm_field): | |
37 | from .types import SQLAlchemyObjectType | |
38 | ||
39 | if not isinstance(obj_type, type) or not issubclass( | |
40 | obj_type, SQLAlchemyObjectType | |
41 | ): | |
42 | raise TypeError( | |
43 | "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) | |
44 | ) | |
45 | if not field_name or not isinstance(field_name, six.string_types): | |
46 | raise TypeError("Expected a field name, but got: {!r}".format(field_name)) | |
47 | self._registry_orm_fields[obj_type][field_name] = orm_field | |
48 | ||
49 | def get_orm_field_for_graphene_field(self, obj_type, field_name): | |
50 | return self._registry_orm_fields.get(obj_type, {}).get(field_name) | |
22 | 51 | |
23 | 52 | def register_composite_converter(self, composite, converter): |
24 | 53 | self._registry_composites[composite] = converter |
25 | 54 | |
26 | 55 | def get_converter_for_composite(self, composite): |
27 | 56 | return self._registry_composites.get(composite) |
57 | ||
58 | def register_enum(self, sa_enum, graphene_enum): | |
59 | if not isinstance(sa_enum, SQLAlchemyEnumType): | |
60 | raise TypeError( | |
61 | "Expected SQLAlchemyEnumType, but got: {!r}".format(sa_enum) | |
62 | ) | |
63 | if not isinstance(graphene_enum, type(Enum)): | |
64 | raise TypeError( | |
65 | "Expected Graphene Enum, but got: {!r}".format(graphene_enum) | |
66 | ) | |
67 | ||
68 | self._registry_enums[sa_enum] = graphene_enum | |
69 | ||
70 | def get_graphene_enum_for_sa_enum(self, sa_enum): | |
71 | return self._registry_enums.get(sa_enum) | |
72 | ||
73 | def register_sort_enum(self, obj_type, sort_enum): | |
74 | from .types import SQLAlchemyObjectType | |
75 | ||
76 | if not isinstance(obj_type, type) or not issubclass( | |
77 | obj_type, SQLAlchemyObjectType | |
78 | ): | |
79 | raise TypeError( | |
80 | "Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type) | |
81 | ) | |
82 | if not isinstance(sort_enum, type(Enum)): | |
83 | raise TypeError("Expected Graphene Enum, but got: {!r}".format(sort_enum)) | |
84 | self._registry_sort_enums[obj_type] = sort_enum | |
85 | ||
86 | def get_sort_enum_for_object_type(self, obj_type): | |
87 | return self._registry_sort_enums.get(obj_type) | |
28 | 88 | |
29 | 89 | |
30 | 90 | registry = None |
0 | from graphene.utils.get_unbound_function import get_unbound_function | |
1 | ||
2 | ||
3 | def get_custom_resolver(obj_type, orm_field_name): | |
4 | """ | |
5 | Since `graphene` will call `resolve_<field_name>` on a field only if it | |
6 | does not have a `resolver`, we need to re-implement that logic here so | |
7 | users are able to override the default resolvers that we provide. | |
8 | """ | |
9 | resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None) | |
10 | if resolver: | |
11 | return get_unbound_function(resolver) | |
12 | ||
13 | return None | |
14 | ||
15 | ||
16 | def get_attr_resolver(obj_type, model_attr): | |
17 | """ | |
18 | In order to support field renaming via `ORMField.model_attr`, | |
19 | we need to define resolver functions for each field. | |
20 | ||
21 | :param SQLAlchemyObjectType obj_type: | |
22 | :param str model_attr: the name of the SQLAlchemy attribute | |
23 | :rtype: Callable | |
24 | """ | |
25 | return lambda root, _info: getattr(root, model_attr, None) |
0 | import pytest | |
1 | from sqlalchemy import create_engine | |
2 | from sqlalchemy.orm import sessionmaker | |
3 | ||
4 | import graphene | |
5 | ||
6 | from ..converter import convert_sqlalchemy_composite | |
7 | from ..registry import reset_global_registry | |
8 | from .models import Base, CompositeFullName | |
9 | ||
10 | test_db_url = 'sqlite://' # use in-memory database for tests | |
11 | ||
12 | ||
13 | @pytest.fixture(autouse=True) | |
14 | def reset_registry(): | |
15 | reset_global_registry() | |
16 | ||
17 | # Prevent tests that implicitly depend on Reporter from raising | |
18 | # Tests that explicitly depend on this behavior should re-register a converter | |
19 | @convert_sqlalchemy_composite.register(CompositeFullName) | |
20 | def convert_composite_class(composite, registry): | |
21 | return graphene.Field(graphene.Int) | |
22 | ||
23 | ||
24 | @pytest.yield_fixture(scope="function") | |
25 | def session_factory(): | |
26 | engine = create_engine(test_db_url) | |
27 | Base.metadata.create_all(engine) | |
28 | ||
29 | yield sessionmaker(bind=engine) | |
30 | ||
31 | # SQLite in-memory db is deleted when its connection is closed. | |
32 | # https://www.sqlite.org/inmemorydb.html | |
33 | engine.dispose() | |
34 | ||
35 | ||
36 | @pytest.fixture(scope="function") | |
37 | def session(session_factory): | |
38 | return session_factory() |
1 | 1 | |
2 | 2 | import enum |
3 | 3 | |
4 | from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table | |
4 | from sqlalchemy import (Column, Date, Enum, ForeignKey, Integer, String, Table, | |
5 | func, select) | |
5 | 6 | from sqlalchemy.ext.declarative import declarative_base |
6 | from sqlalchemy.orm import mapper, relationship | |
7 | from sqlalchemy.ext.hybrid import hybrid_property | |
8 | from sqlalchemy.orm import column_property, composite, mapper, relationship | |
9 | ||
10 | PetKind = Enum("cat", "dog", name="pet_kind") | |
7 | 11 | |
8 | 12 | |
9 | class Hairkind(enum.Enum): | |
13 | class HairKind(enum.Enum): | |
10 | 14 | LONG = 'long' |
11 | 15 | SHORT = 'short' |
12 | 16 | |
31 | 35 | __tablename__ = "pets" |
32 | 36 | id = Column(Integer(), primary_key=True) |
33 | 37 | name = Column(String(30)) |
34 | pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) | |
35 | hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) | |
38 | pet_kind = Column(PetKind, nullable=False) | |
39 | hair_kind = Column(Enum(HairKind, name="hair_kind"), nullable=False) | |
36 | 40 | reporter_id = Column(Integer(), ForeignKey("reporters.id")) |
41 | ||
42 | ||
43 | class CompositeFullName(object): | |
44 | def __init__(self, first_name, last_name): | |
45 | self.first_name = first_name | |
46 | self.last_name = last_name | |
47 | ||
48 | def __composite_values__(self): | |
49 | return self.first_name, self.last_name | |
50 | ||
51 | def __repr__(self): | |
52 | return "{} {}".format(self.first_name, self.last_name) | |
37 | 53 | |
38 | 54 | |
39 | 55 | class Reporter(Base): |
40 | 56 | __tablename__ = "reporters" |
57 | ||
41 | 58 | id = Column(Integer(), primary_key=True) |
42 | first_name = Column(String(30)) | |
43 | last_name = Column(String(30)) | |
44 | email = Column(String()) | |
45 | pets = relationship("Pet", secondary=association_table, backref="reporters") | |
59 | first_name = Column(String(30), doc="First name") | |
60 | last_name = Column(String(30), doc="Last name") | |
61 | email = Column(String(), doc="Email") | |
62 | favorite_pet_kind = Column(PetKind) | |
63 | pets = relationship("Pet", secondary=association_table, backref="reporters", order_by="Pet.id") | |
46 | 64 | articles = relationship("Article", backref="reporter") |
47 | 65 | favorite_article = relationship("Article", uselist=False) |
48 | 66 | |
49 | # total = column_property( | |
50 | # select([ | |
51 | # func.cast(func.count(PersonInfo.id), Float) | |
52 | # ]) | |
53 | # ) | |
67 | @hybrid_property | |
68 | def hybrid_prop(self): | |
69 | return self.first_name | |
70 | ||
71 | column_prop = column_property( | |
72 | select([func.cast(func.count(id), Integer)]), doc="Column property" | |
73 | ) | |
74 | ||
75 | composite_prop = composite(CompositeFullName, first_name, last_name, doc="Composite") | |
54 | 76 | |
55 | 77 | |
56 | 78 | class Article(Base): |
0 | import contextlib | |
1 | import logging | |
2 | ||
3 | import pytest | |
4 | ||
5 | import graphene | |
6 | from graphene import relay | |
7 | ||
8 | from ..fields import (BatchSQLAlchemyConnectionField, | |
9 | default_connection_field_factory) | |
10 | from ..types import ORMField, SQLAlchemyObjectType | |
11 | from .models import Article, HairKind, Pet, Reporter | |
12 | from .utils import is_sqlalchemy_version_less_than, to_std_dicts | |
13 | ||
14 | ||
15 | class MockLoggingHandler(logging.Handler): | |
16 | """Intercept and store log messages in a list.""" | |
17 | def __init__(self, *args, **kwargs): | |
18 | self.messages = [] | |
19 | logging.Handler.__init__(self, *args, **kwargs) | |
20 | ||
21 | def emit(self, record): | |
22 | self.messages.append(record.getMessage()) | |
23 | ||
24 | ||
25 | @contextlib.contextmanager | |
26 | def mock_sqlalchemy_logging_handler(): | |
27 | logging.basicConfig() | |
28 | sql_logger = logging.getLogger('sqlalchemy.engine') | |
29 | previous_level = sql_logger.level | |
30 | ||
31 | sql_logger.setLevel(logging.INFO) | |
32 | mock_logging_handler = MockLoggingHandler() | |
33 | mock_logging_handler.setLevel(logging.INFO) | |
34 | sql_logger.addHandler(mock_logging_handler) | |
35 | ||
36 | yield mock_logging_handler | |
37 | ||
38 | sql_logger.setLevel(previous_level) | |
39 | ||
40 | ||
41 | def get_schema(): | |
42 | class ReporterType(SQLAlchemyObjectType): | |
43 | class Meta: | |
44 | model = Reporter | |
45 | interfaces = (relay.Node,) | |
46 | batching = True | |
47 | ||
48 | class ArticleType(SQLAlchemyObjectType): | |
49 | class Meta: | |
50 | model = Article | |
51 | interfaces = (relay.Node,) | |
52 | batching = True | |
53 | ||
54 | class PetType(SQLAlchemyObjectType): | |
55 | class Meta: | |
56 | model = Pet | |
57 | interfaces = (relay.Node,) | |
58 | batching = True | |
59 | ||
60 | class Query(graphene.ObjectType): | |
61 | articles = graphene.Field(graphene.List(ArticleType)) | |
62 | reporters = graphene.Field(graphene.List(ReporterType)) | |
63 | ||
64 | def resolve_articles(self, info): | |
65 | return info.context.get('session').query(Article).all() | |
66 | ||
67 | def resolve_reporters(self, info): | |
68 | return info.context.get('session').query(Reporter).all() | |
69 | ||
70 | return graphene.Schema(query=Query) | |
71 | ||
72 | ||
73 | if is_sqlalchemy_version_less_than('1.2'): | |
74 | pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) | |
75 | ||
76 | ||
77 | def test_many_to_one(session_factory): | |
78 | session = session_factory() | |
79 | ||
80 | reporter_1 = Reporter( | |
81 | first_name='Reporter_1', | |
82 | ) | |
83 | session.add(reporter_1) | |
84 | reporter_2 = Reporter( | |
85 | first_name='Reporter_2', | |
86 | ) | |
87 | session.add(reporter_2) | |
88 | ||
89 | article_1 = Article(headline='Article_1') | |
90 | article_1.reporter = reporter_1 | |
91 | session.add(article_1) | |
92 | ||
93 | article_2 = Article(headline='Article_2') | |
94 | article_2.reporter = reporter_2 | |
95 | session.add(article_2) | |
96 | ||
97 | session.commit() | |
98 | session.close() | |
99 | ||
100 | schema = get_schema() | |
101 | ||
102 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
103 | # Starts new session to fully reset the engine / connection logging level | |
104 | session = session_factory() | |
105 | result = schema.execute(""" | |
106 | query { | |
107 | articles { | |
108 | headline | |
109 | reporter { | |
110 | firstName | |
111 | } | |
112 | } | |
113 | } | |
114 | """, context_value={"session": session}) | |
115 | messages = sqlalchemy_logging_handler.messages | |
116 | ||
117 | assert len(messages) == 5 | |
118 | ||
119 | if is_sqlalchemy_version_less_than('1.3'): | |
120 | # The batched SQL statement generated is different in 1.2.x | |
121 | # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` | |
122 | # See https://git.io/JewQu | |
123 | sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN reporters' in message] | |
124 | assert len(sql_statements) == 1 | |
125 | return | |
126 | ||
127 | assert messages == [ | |
128 | 'BEGIN (implicit)', | |
129 | ||
130 | 'SELECT articles.id AS articles_id, ' | |
131 | 'articles.headline AS articles_headline, ' | |
132 | 'articles.pub_date AS articles_pub_date, ' | |
133 | 'articles.reporter_id AS articles_reporter_id \n' | |
134 | 'FROM articles', | |
135 | '()', | |
136 | ||
137 | 'SELECT reporters.id AS reporters_id, ' | |
138 | '(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' | |
139 | 'reporters.first_name AS reporters_first_name, ' | |
140 | 'reporters.last_name AS reporters_last_name, ' | |
141 | 'reporters.email AS reporters_email, ' | |
142 | 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' | |
143 | 'FROM reporters \n' | |
144 | 'WHERE reporters.id IN (?, ?)', | |
145 | '(1, 2)', | |
146 | ] | |
147 | ||
148 | assert not result.errors | |
149 | result = to_std_dicts(result.data) | |
150 | assert result == { | |
151 | "articles": [ | |
152 | { | |
153 | "headline": "Article_1", | |
154 | "reporter": { | |
155 | "firstName": "Reporter_1", | |
156 | }, | |
157 | }, | |
158 | { | |
159 | "headline": "Article_2", | |
160 | "reporter": { | |
161 | "firstName": "Reporter_2", | |
162 | }, | |
163 | }, | |
164 | ], | |
165 | } | |
166 | ||
167 | ||
168 | def test_one_to_one(session_factory): | |
169 | session = session_factory() | |
170 | ||
171 | reporter_1 = Reporter( | |
172 | first_name='Reporter_1', | |
173 | ) | |
174 | session.add(reporter_1) | |
175 | reporter_2 = Reporter( | |
176 | first_name='Reporter_2', | |
177 | ) | |
178 | session.add(reporter_2) | |
179 | ||
180 | article_1 = Article(headline='Article_1') | |
181 | article_1.reporter = reporter_1 | |
182 | session.add(article_1) | |
183 | ||
184 | article_2 = Article(headline='Article_2') | |
185 | article_2.reporter = reporter_2 | |
186 | session.add(article_2) | |
187 | ||
188 | session.commit() | |
189 | session.close() | |
190 | ||
191 | schema = get_schema() | |
192 | ||
193 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
194 | # Starts new session to fully reset the engine / connection logging level | |
195 | session = session_factory() | |
196 | result = schema.execute(""" | |
197 | query { | |
198 | reporters { | |
199 | firstName | |
200 | favoriteArticle { | |
201 | headline | |
202 | } | |
203 | } | |
204 | } | |
205 | """, context_value={"session": session}) | |
206 | messages = sqlalchemy_logging_handler.messages | |
207 | ||
208 | assert len(messages) == 5 | |
209 | ||
210 | if is_sqlalchemy_version_less_than('1.3'): | |
211 | # The batched SQL statement generated is different in 1.2.x | |
212 | # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` | |
213 | # See https://git.io/JewQu | |
214 | sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] | |
215 | assert len(sql_statements) == 1 | |
216 | return | |
217 | ||
218 | assert messages == [ | |
219 | 'BEGIN (implicit)', | |
220 | ||
221 | 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' | |
222 | 'reporters.id AS reporters_id, ' | |
223 | 'reporters.first_name AS reporters_first_name, ' | |
224 | 'reporters.last_name AS reporters_last_name, ' | |
225 | 'reporters.email AS reporters_email, ' | |
226 | 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' | |
227 | 'FROM reporters', | |
228 | '()', | |
229 | ||
230 | 'SELECT articles.reporter_id AS articles_reporter_id, ' | |
231 | 'articles.id AS articles_id, ' | |
232 | 'articles.headline AS articles_headline, ' | |
233 | 'articles.pub_date AS articles_pub_date \n' | |
234 | 'FROM articles \n' | |
235 | 'WHERE articles.reporter_id IN (?, ?)', | |
236 | '(1, 2)' | |
237 | ] | |
238 | ||
239 | assert not result.errors | |
240 | result = to_std_dicts(result.data) | |
241 | assert result == { | |
242 | "reporters": [ | |
243 | { | |
244 | "firstName": "Reporter_1", | |
245 | "favoriteArticle": { | |
246 | "headline": "Article_1", | |
247 | }, | |
248 | }, | |
249 | { | |
250 | "firstName": "Reporter_2", | |
251 | "favoriteArticle": { | |
252 | "headline": "Article_2", | |
253 | }, | |
254 | }, | |
255 | ], | |
256 | } | |
257 | ||
258 | ||
259 | def test_one_to_many(session_factory): | |
260 | session = session_factory() | |
261 | ||
262 | reporter_1 = Reporter( | |
263 | first_name='Reporter_1', | |
264 | ) | |
265 | session.add(reporter_1) | |
266 | reporter_2 = Reporter( | |
267 | first_name='Reporter_2', | |
268 | ) | |
269 | session.add(reporter_2) | |
270 | ||
271 | article_1 = Article(headline='Article_1') | |
272 | article_1.reporter = reporter_1 | |
273 | session.add(article_1) | |
274 | ||
275 | article_2 = Article(headline='Article_2') | |
276 | article_2.reporter = reporter_1 | |
277 | session.add(article_2) | |
278 | ||
279 | article_3 = Article(headline='Article_3') | |
280 | article_3.reporter = reporter_2 | |
281 | session.add(article_3) | |
282 | ||
283 | article_4 = Article(headline='Article_4') | |
284 | article_4.reporter = reporter_2 | |
285 | session.add(article_4) | |
286 | ||
287 | session.commit() | |
288 | session.close() | |
289 | ||
290 | schema = get_schema() | |
291 | ||
292 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
293 | # Starts new session to fully reset the engine / connection logging level | |
294 | session = session_factory() | |
295 | result = schema.execute(""" | |
296 | query { | |
297 | reporters { | |
298 | firstName | |
299 | articles(first: 2) { | |
300 | edges { | |
301 | node { | |
302 | headline | |
303 | } | |
304 | } | |
305 | } | |
306 | } | |
307 | } | |
308 | """, context_value={"session": session}) | |
309 | messages = sqlalchemy_logging_handler.messages | |
310 | ||
311 | assert len(messages) == 5 | |
312 | ||
313 | if is_sqlalchemy_version_less_than('1.3'): | |
314 | # The batched SQL statement generated is different in 1.2.x | |
315 | # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` | |
316 | # See https://git.io/JewQu | |
317 | sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] | |
318 | assert len(sql_statements) == 1 | |
319 | return | |
320 | ||
321 | assert messages == [ | |
322 | 'BEGIN (implicit)', | |
323 | ||
324 | 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' | |
325 | 'reporters.id AS reporters_id, ' | |
326 | 'reporters.first_name AS reporters_first_name, ' | |
327 | 'reporters.last_name AS reporters_last_name, ' | |
328 | 'reporters.email AS reporters_email, ' | |
329 | 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' | |
330 | 'FROM reporters', | |
331 | '()', | |
332 | ||
333 | 'SELECT articles.reporter_id AS articles_reporter_id, ' | |
334 | 'articles.id AS articles_id, ' | |
335 | 'articles.headline AS articles_headline, ' | |
336 | 'articles.pub_date AS articles_pub_date \n' | |
337 | 'FROM articles \n' | |
338 | 'WHERE articles.reporter_id IN (?, ?)', | |
339 | '(1, 2)' | |
340 | ] | |
341 | ||
342 | assert not result.errors | |
343 | result = to_std_dicts(result.data) | |
344 | assert result == { | |
345 | "reporters": [ | |
346 | { | |
347 | "firstName": "Reporter_1", | |
348 | "articles": { | |
349 | "edges": [ | |
350 | { | |
351 | "node": { | |
352 | "headline": "Article_1", | |
353 | }, | |
354 | }, | |
355 | { | |
356 | "node": { | |
357 | "headline": "Article_2", | |
358 | }, | |
359 | }, | |
360 | ], | |
361 | }, | |
362 | }, | |
363 | { | |
364 | "firstName": "Reporter_2", | |
365 | "articles": { | |
366 | "edges": [ | |
367 | { | |
368 | "node": { | |
369 | "headline": "Article_3", | |
370 | }, | |
371 | }, | |
372 | { | |
373 | "node": { | |
374 | "headline": "Article_4", | |
375 | }, | |
376 | }, | |
377 | ], | |
378 | }, | |
379 | }, | |
380 | ], | |
381 | } | |
382 | ||
383 | ||
384 | def test_many_to_many(session_factory): | |
385 | session = session_factory() | |
386 | ||
387 | reporter_1 = Reporter( | |
388 | first_name='Reporter_1', | |
389 | ) | |
390 | session.add(reporter_1) | |
391 | reporter_2 = Reporter( | |
392 | first_name='Reporter_2', | |
393 | ) | |
394 | session.add(reporter_2) | |
395 | ||
396 | pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) | |
397 | session.add(pet_1) | |
398 | ||
399 | pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) | |
400 | session.add(pet_2) | |
401 | ||
402 | reporter_1.pets.append(pet_1) | |
403 | reporter_1.pets.append(pet_2) | |
404 | ||
405 | pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) | |
406 | session.add(pet_3) | |
407 | ||
408 | pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) | |
409 | session.add(pet_4) | |
410 | ||
411 | reporter_2.pets.append(pet_3) | |
412 | reporter_2.pets.append(pet_4) | |
413 | ||
414 | session.commit() | |
415 | session.close() | |
416 | ||
417 | schema = get_schema() | |
418 | ||
419 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
420 | # Starts new session to fully reset the engine / connection logging level | |
421 | session = session_factory() | |
422 | result = schema.execute(""" | |
423 | query { | |
424 | reporters { | |
425 | firstName | |
426 | pets(first: 2) { | |
427 | edges { | |
428 | node { | |
429 | name | |
430 | } | |
431 | } | |
432 | } | |
433 | } | |
434 | } | |
435 | """, context_value={"session": session}) | |
436 | messages = sqlalchemy_logging_handler.messages | |
437 | ||
438 | assert len(messages) == 5 | |
439 | ||
440 | if is_sqlalchemy_version_less_than('1.3'): | |
441 | # The batched SQL statement generated is different in 1.2.x | |
442 | # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` | |
443 | # See https://git.io/JewQu | |
444 | sql_statements = [message for message in messages if 'SELECT' in message and 'JOIN pets' in message] | |
445 | assert len(sql_statements) == 1 | |
446 | return | |
447 | ||
448 | assert messages == [ | |
449 | 'BEGIN (implicit)', | |
450 | ||
451 | 'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, ' | |
452 | 'reporters.id AS reporters_id, ' | |
453 | 'reporters.first_name AS reporters_first_name, ' | |
454 | 'reporters.last_name AS reporters_last_name, ' | |
455 | 'reporters.email AS reporters_email, ' | |
456 | 'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n' | |
457 | 'FROM reporters', | |
458 | '()', | |
459 | ||
460 | 'SELECT reporters_1.id AS reporters_1_id, ' | |
461 | 'pets.id AS pets_id, ' | |
462 | 'pets.name AS pets_name, ' | |
463 | 'pets.pet_kind AS pets_pet_kind, ' | |
464 | 'pets.hair_kind AS pets_hair_kind, ' | |
465 | 'pets.reporter_id AS pets_reporter_id \n' | |
466 | 'FROM reporters AS reporters_1 ' | |
467 | 'JOIN association AS association_1 ON reporters_1.id = association_1.reporter_id ' | |
468 | 'JOIN pets ON pets.id = association_1.pet_id \n' | |
469 | 'WHERE reporters_1.id IN (?, ?) ' | |
470 | 'ORDER BY pets.id', | |
471 | '(1, 2)' | |
472 | ] | |
473 | ||
474 | assert not result.errors | |
475 | result = to_std_dicts(result.data) | |
476 | assert result == { | |
477 | "reporters": [ | |
478 | { | |
479 | "firstName": "Reporter_1", | |
480 | "pets": { | |
481 | "edges": [ | |
482 | { | |
483 | "node": { | |
484 | "name": "Pet_1", | |
485 | }, | |
486 | }, | |
487 | { | |
488 | "node": { | |
489 | "name": "Pet_2", | |
490 | }, | |
491 | }, | |
492 | ], | |
493 | }, | |
494 | }, | |
495 | { | |
496 | "firstName": "Reporter_2", | |
497 | "pets": { | |
498 | "edges": [ | |
499 | { | |
500 | "node": { | |
501 | "name": "Pet_3", | |
502 | }, | |
503 | }, | |
504 | { | |
505 | "node": { | |
506 | "name": "Pet_4", | |
507 | }, | |
508 | }, | |
509 | ], | |
510 | }, | |
511 | }, | |
512 | ], | |
513 | } | |
514 | ||
515 | ||
516 | def test_disable_batching_via_ormfield(session_factory): | |
517 | session = session_factory() | |
518 | reporter_1 = Reporter(first_name='Reporter_1') | |
519 | session.add(reporter_1) | |
520 | reporter_2 = Reporter(first_name='Reporter_2') | |
521 | session.add(reporter_2) | |
522 | session.commit() | |
523 | session.close() | |
524 | ||
525 | class ReporterType(SQLAlchemyObjectType): | |
526 | class Meta: | |
527 | model = Reporter | |
528 | interfaces = (relay.Node,) | |
529 | batching = True | |
530 | ||
531 | favorite_article = ORMField(batching=False) | |
532 | articles = ORMField(batching=False) | |
533 | ||
534 | class ArticleType(SQLAlchemyObjectType): | |
535 | class Meta: | |
536 | model = Article | |
537 | interfaces = (relay.Node,) | |
538 | ||
539 | class Query(graphene.ObjectType): | |
540 | reporters = graphene.Field(graphene.List(ReporterType)) | |
541 | ||
542 | def resolve_reporters(self, info): | |
543 | return info.context.get('session').query(Reporter).all() | |
544 | ||
545 | schema = graphene.Schema(query=Query) | |
546 | ||
547 | # Test one-to-one and many-to-one relationships | |
548 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
549 | # Starts new session to fully reset the engine / connection logging level | |
550 | session = session_factory() | |
551 | schema.execute(""" | |
552 | query { | |
553 | reporters { | |
554 | favoriteArticle { | |
555 | headline | |
556 | } | |
557 | } | |
558 | } | |
559 | """, context_value={"session": session}) | |
560 | messages = sqlalchemy_logging_handler.messages | |
561 | ||
562 | select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] | |
563 | assert len(select_statements) == 2 | |
564 | ||
565 | # Test one-to-many and many-to-many relationships | |
566 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
567 | # Starts new session to fully reset the engine / connection logging level | |
568 | session = session_factory() | |
569 | schema.execute(""" | |
570 | query { | |
571 | reporters { | |
572 | articles { | |
573 | edges { | |
574 | node { | |
575 | headline | |
576 | } | |
577 | } | |
578 | } | |
579 | } | |
580 | } | |
581 | """, context_value={"session": session}) | |
582 | messages = sqlalchemy_logging_handler.messages | |
583 | ||
584 | select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] | |
585 | assert len(select_statements) == 2 | |
586 | ||
587 | ||
588 | def test_connection_factory_field_overrides_batching_is_false(session_factory): | |
589 | session = session_factory() | |
590 | reporter_1 = Reporter(first_name='Reporter_1') | |
591 | session.add(reporter_1) | |
592 | reporter_2 = Reporter(first_name='Reporter_2') | |
593 | session.add(reporter_2) | |
594 | session.commit() | |
595 | session.close() | |
596 | ||
597 | class ReporterType(SQLAlchemyObjectType): | |
598 | class Meta: | |
599 | model = Reporter | |
600 | interfaces = (relay.Node,) | |
601 | batching = False | |
602 | connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship | |
603 | ||
604 | articles = ORMField(batching=False) | |
605 | ||
606 | class ArticleType(SQLAlchemyObjectType): | |
607 | class Meta: | |
608 | model = Article | |
609 | interfaces = (relay.Node,) | |
610 | ||
611 | class Query(graphene.ObjectType): | |
612 | reporters = graphene.Field(graphene.List(ReporterType)) | |
613 | ||
614 | def resolve_reporters(self, info): | |
615 | return info.context.get('session').query(Reporter).all() | |
616 | ||
617 | schema = graphene.Schema(query=Query) | |
618 | ||
619 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
620 | # Starts new session to fully reset the engine / connection logging level | |
621 | session = session_factory() | |
622 | schema.execute(""" | |
623 | query { | |
624 | reporters { | |
625 | articles { | |
626 | edges { | |
627 | node { | |
628 | headline | |
629 | } | |
630 | } | |
631 | } | |
632 | } | |
633 | } | |
634 | """, context_value={"session": session}) | |
635 | messages = sqlalchemy_logging_handler.messages | |
636 | ||
637 | if is_sqlalchemy_version_less_than('1.3'): | |
638 | # The batched SQL statement generated is different in 1.2.x | |
639 | # SQLAlchemy 1.3+ optimizes out a JOIN statement in `selectin` | |
640 | # See https://git.io/JewQu | |
641 | select_statements = [message for message in messages if 'SELECT' in message and 'JOIN articles' in message] | |
642 | else: | |
643 | select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] | |
644 | assert len(select_statements) == 1 | |
645 | ||
646 | ||
647 | def test_connection_factory_field_overrides_batching_is_true(session_factory): | |
648 | session = session_factory() | |
649 | reporter_1 = Reporter(first_name='Reporter_1') | |
650 | session.add(reporter_1) | |
651 | reporter_2 = Reporter(first_name='Reporter_2') | |
652 | session.add(reporter_2) | |
653 | session.commit() | |
654 | session.close() | |
655 | ||
656 | class ReporterType(SQLAlchemyObjectType): | |
657 | class Meta: | |
658 | model = Reporter | |
659 | interfaces = (relay.Node,) | |
660 | batching = True | |
661 | connection_field_factory = default_connection_field_factory | |
662 | ||
663 | articles = ORMField(batching=True) | |
664 | ||
665 | class ArticleType(SQLAlchemyObjectType): | |
666 | class Meta: | |
667 | model = Article | |
668 | interfaces = (relay.Node,) | |
669 | ||
670 | class Query(graphene.ObjectType): | |
671 | reporters = graphene.Field(graphene.List(ReporterType)) | |
672 | ||
673 | def resolve_reporters(self, info): | |
674 | return info.context.get('session').query(Reporter).all() | |
675 | ||
676 | schema = graphene.Schema(query=Query) | |
677 | ||
678 | with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler: | |
679 | # Starts new session to fully reset the engine / connection logging level | |
680 | session = session_factory() | |
681 | schema.execute(""" | |
682 | query { | |
683 | reporters { | |
684 | articles { | |
685 | edges { | |
686 | node { | |
687 | headline | |
688 | } | |
689 | } | |
690 | } | |
691 | } | |
692 | } | |
693 | """, context_value={"session": session}) | |
694 | messages = sqlalchemy_logging_handler.messages | |
695 | ||
696 | select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message] | |
697 | assert len(select_statements) == 2 |
0 | import pytest | |
1 | from graphql.backend import GraphQLCachedBackend, GraphQLCoreBackend | |
2 | ||
3 | import graphene | |
4 | from graphene import relay | |
5 | ||
6 | from ..fields import BatchSQLAlchemyConnectionField | |
7 | from ..types import SQLAlchemyObjectType | |
8 | from .models import Article, HairKind, Pet, Reporter | |
9 | from .utils import is_sqlalchemy_version_less_than | |
10 | ||
11 | if is_sqlalchemy_version_less_than('1.2'): | |
12 | pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True) | |
13 | ||
14 | ||
15 | def get_schema(): | |
16 | class ReporterType(SQLAlchemyObjectType): | |
17 | class Meta: | |
18 | model = Reporter | |
19 | interfaces = (relay.Node,) | |
20 | connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship | |
21 | ||
22 | class ArticleType(SQLAlchemyObjectType): | |
23 | class Meta: | |
24 | model = Article | |
25 | interfaces = (relay.Node,) | |
26 | connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship | |
27 | ||
28 | class PetType(SQLAlchemyObjectType): | |
29 | class Meta: | |
30 | model = Pet | |
31 | interfaces = (relay.Node,) | |
32 | connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship | |
33 | ||
34 | class Query(graphene.ObjectType): | |
35 | articles = graphene.Field(graphene.List(ArticleType)) | |
36 | reporters = graphene.Field(graphene.List(ReporterType)) | |
37 | ||
38 | def resolve_articles(self, info): | |
39 | return info.context.get('session').query(Article).all() | |
40 | ||
41 | def resolve_reporters(self, info): | |
42 | return info.context.get('session').query(Reporter).all() | |
43 | ||
44 | return graphene.Schema(query=Query) | |
45 | ||
46 | ||
47 | def benchmark_query(session_factory, benchmark, query): | |
48 | schema = get_schema() | |
49 | cached_backend = GraphQLCachedBackend(GraphQLCoreBackend()) | |
50 | cached_backend.document_from_string(schema, query) # Prime cache | |
51 | ||
52 | @benchmark | |
53 | def execute_query(): | |
54 | result = schema.execute( | |
55 | query, | |
56 | context_value={"session": session_factory()}, | |
57 | backend=cached_backend, | |
58 | ) | |
59 | assert not result.errors | |
60 | ||
61 | ||
62 | def test_one_to_one(session_factory, benchmark): | |
63 | session = session_factory() | |
64 | ||
65 | reporter_1 = Reporter( | |
66 | first_name='Reporter_1', | |
67 | ) | |
68 | session.add(reporter_1) | |
69 | reporter_2 = Reporter( | |
70 | first_name='Reporter_2', | |
71 | ) | |
72 | session.add(reporter_2) | |
73 | ||
74 | article_1 = Article(headline='Article_1') | |
75 | article_1.reporter = reporter_1 | |
76 | session.add(article_1) | |
77 | ||
78 | article_2 = Article(headline='Article_2') | |
79 | article_2.reporter = reporter_2 | |
80 | session.add(article_2) | |
81 | ||
82 | session.commit() | |
83 | session.close() | |
84 | ||
85 | benchmark_query(session_factory, benchmark, """ | |
86 | query { | |
87 | reporters { | |
88 | firstName | |
89 | favoriteArticle { | |
90 | headline | |
91 | } | |
92 | } | |
93 | } | |
94 | """) | |
95 | ||
96 | ||
97 | def test_many_to_one(session_factory, benchmark): | |
98 | session = session_factory() | |
99 | ||
100 | reporter_1 = Reporter( | |
101 | first_name='Reporter_1', | |
102 | ) | |
103 | session.add(reporter_1) | |
104 | reporter_2 = Reporter( | |
105 | first_name='Reporter_2', | |
106 | ) | |
107 | session.add(reporter_2) | |
108 | ||
109 | article_1 = Article(headline='Article_1') | |
110 | article_1.reporter = reporter_1 | |
111 | session.add(article_1) | |
112 | ||
113 | article_2 = Article(headline='Article_2') | |
114 | article_2.reporter = reporter_2 | |
115 | session.add(article_2) | |
116 | ||
117 | session.commit() | |
118 | session.close() | |
119 | ||
120 | benchmark_query(session_factory, benchmark, """ | |
121 | query { | |
122 | articles { | |
123 | headline | |
124 | reporter { | |
125 | firstName | |
126 | } | |
127 | } | |
128 | } | |
129 | """) | |
130 | ||
131 | ||
132 | def test_one_to_many(session_factory, benchmark): | |
133 | session = session_factory() | |
134 | ||
135 | reporter_1 = Reporter( | |
136 | first_name='Reporter_1', | |
137 | ) | |
138 | session.add(reporter_1) | |
139 | reporter_2 = Reporter( | |
140 | first_name='Reporter_2', | |
141 | ) | |
142 | session.add(reporter_2) | |
143 | ||
144 | article_1 = Article(headline='Article_1') | |
145 | article_1.reporter = reporter_1 | |
146 | session.add(article_1) | |
147 | ||
148 | article_2 = Article(headline='Article_2') | |
149 | article_2.reporter = reporter_1 | |
150 | session.add(article_2) | |
151 | ||
152 | article_3 = Article(headline='Article_3') | |
153 | article_3.reporter = reporter_2 | |
154 | session.add(article_3) | |
155 | ||
156 | article_4 = Article(headline='Article_4') | |
157 | article_4.reporter = reporter_2 | |
158 | session.add(article_4) | |
159 | ||
160 | session.commit() | |
161 | session.close() | |
162 | ||
163 | benchmark_query(session_factory, benchmark, """ | |
164 | query { | |
165 | reporters { | |
166 | firstName | |
167 | articles(first: 2) { | |
168 | edges { | |
169 | node { | |
170 | headline | |
171 | } | |
172 | } | |
173 | } | |
174 | } | |
175 | } | |
176 | """) | |
177 | ||
178 | ||
179 | def test_many_to_many(session_factory, benchmark): | |
180 | session = session_factory() | |
181 | ||
182 | reporter_1 = Reporter( | |
183 | first_name='Reporter_1', | |
184 | ) | |
185 | session.add(reporter_1) | |
186 | reporter_2 = Reporter( | |
187 | first_name='Reporter_2', | |
188 | ) | |
189 | session.add(reporter_2) | |
190 | ||
191 | pet_1 = Pet(name='Pet_1', pet_kind='cat', hair_kind=HairKind.LONG) | |
192 | session.add(pet_1) | |
193 | ||
194 | pet_2 = Pet(name='Pet_2', pet_kind='cat', hair_kind=HairKind.LONG) | |
195 | session.add(pet_2) | |
196 | ||
197 | reporter_1.pets.append(pet_1) | |
198 | reporter_1.pets.append(pet_2) | |
199 | ||
200 | pet_3 = Pet(name='Pet_3', pet_kind='cat', hair_kind=HairKind.LONG) | |
201 | session.add(pet_3) | |
202 | ||
203 | pet_4 = Pet(name='Pet_4', pet_kind='cat', hair_kind=HairKind.LONG) | |
204 | session.add(pet_4) | |
205 | ||
206 | reporter_2.pets.append(pet_3) | |
207 | reporter_2.pets.append(pet_4) | |
208 | ||
209 | session.commit() | |
210 | session.close() | |
211 | ||
212 | benchmark_query(session_factory, benchmark, """ | |
213 | query { | |
214 | reporters { | |
215 | firstName | |
216 | pets(first: 2) { | |
217 | edges { | |
218 | node { | |
219 | name | |
220 | } | |
221 | } | |
222 | } | |
223 | } | |
224 | } | |
225 | """) |
0 | 0 | import enum |
1 | 1 | |
2 | from py.test import raises | |
3 | from sqlalchemy import Column, Table, case, func, select, types | |
2 | import pytest | |
3 | from sqlalchemy import Column, func, select, types | |
4 | 4 | from sqlalchemy.dialects import postgresql |
5 | 5 | from sqlalchemy.ext.declarative import declarative_base |
6 | from sqlalchemy.inspection import inspect | |
6 | 7 | from sqlalchemy.orm import column_property, composite |
7 | from sqlalchemy.sql.elements import Label | |
8 | 8 | from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType |
9 | 9 | |
10 | 10 | import graphene |
17 | 17 | convert_sqlalchemy_relationship) |
18 | 18 | from ..fields import (UnsortedSQLAlchemyConnectionField, |
19 | 19 | default_connection_field_factory) |
20 | from ..registry import Registry | |
20 | from ..registry import Registry, get_global_registry | |
21 | 21 | from ..types import SQLAlchemyObjectType |
22 | from .models import Article, Pet, Reporter | |
23 | ||
24 | ||
25 | def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): | |
26 | column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) | |
27 | graphene_type = convert_sqlalchemy_column(column) | |
28 | assert isinstance(graphene_type, graphene_field) | |
29 | field = ( | |
30 | graphene_type | |
31 | if isinstance(graphene_type, graphene.Field) | |
32 | else graphene_type.Field() | |
33 | ) | |
34 | assert field.description == "Custom Help Text" | |
35 | return field | |
36 | ||
37 | ||
38 | def assert_composite_conversion( | |
39 | composite_class, composite_columns, graphene_field, registry, **kwargs | |
40 | ): | |
41 | composite_column = composite( | |
42 | composite_class, *composite_columns, doc="Custom Help Text", **kwargs | |
43 | ) | |
44 | graphene_type = convert_sqlalchemy_composite(composite_column, registry) | |
45 | assert isinstance(graphene_type, graphene_field) | |
46 | field = graphene_type.Field() | |
47 | # SQLAlchemy currently does not persist the doc onto the column, even though | |
48 | # the documentation says it does.... | |
49 | # assert field.description == 'Custom Help Text' | |
50 | return field | |
22 | from .models import Article, CompositeFullName, Pet, Reporter | |
23 | ||
24 | ||
25 | def mock_resolver(): | |
26 | pass | |
27 | ||
28 | ||
29 | def get_field(sqlalchemy_type, **column_kwargs): | |
30 | class Model(declarative_base()): | |
31 | __tablename__ = 'model' | |
32 | id_ = Column(types.Integer, primary_key=True) | |
33 | column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) | |
34 | ||
35 | column_prop = inspect(Model).column_attrs['column'] | |
36 | return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) | |
37 | ||
38 | ||
39 | def get_field_from_column(column_): | |
40 | class Model(declarative_base()): | |
41 | __tablename__ = 'model' | |
42 | id_ = Column(types.Integer, primary_key=True) | |
43 | column = column_ | |
44 | ||
45 | column_prop = inspect(Model).column_attrs['column'] | |
46 | return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) | |
51 | 47 | |
52 | 48 | |
53 | 49 | def test_should_unknown_sqlalchemy_field_raise_exception(): |
54 | with raises(Exception) as excinfo: | |
55 | convert_sqlalchemy_column(None) | |
56 | assert "Don't know how to convert the SQLAlchemy field" in str(excinfo.value) | |
50 | re_err = "Don't know how to convert the SQLAlchemy field" | |
51 | with pytest.raises(Exception, match=re_err): | |
52 | # support legacy Binary type and subsequent LargeBinary | |
53 | get_field(getattr(types, 'LargeBinary', types.Binary)()) | |
57 | 54 | |
58 | 55 | |
59 | 56 | def test_should_date_convert_string(): |
60 | assert_column_conversion(types.Date(), graphene.String) | |
61 | ||
62 | ||
63 | def test_should_datetime_convert_string(): | |
64 | assert_column_conversion(types.DateTime(), DateTime) | |
57 | assert get_field(types.Date()).type == graphene.String | |
58 | ||
59 | ||
60 | def test_should_datetime_convert_datetime(): | |
61 | assert get_field(types.DateTime()).type == DateTime | |
65 | 62 | |
66 | 63 | |
67 | 64 | def test_should_time_convert_string(): |
68 | assert_column_conversion(types.Time(), graphene.String) | |
65 | assert get_field(types.Time()).type == graphene.String | |
69 | 66 | |
70 | 67 | |
71 | 68 | def test_should_string_convert_string(): |
72 | assert_column_conversion(types.String(), graphene.String) | |
69 | assert get_field(types.String()).type == graphene.String | |
73 | 70 | |
74 | 71 | |
75 | 72 | def test_should_text_convert_string(): |
76 | assert_column_conversion(types.Text(), graphene.String) | |
73 | assert get_field(types.Text()).type == graphene.String | |
77 | 74 | |
78 | 75 | |
79 | 76 | def test_should_unicode_convert_string(): |
80 | assert_column_conversion(types.Unicode(), graphene.String) | |
77 | assert get_field(types.Unicode()).type == graphene.String | |
81 | 78 | |
82 | 79 | |
83 | 80 | def test_should_unicodetext_convert_string(): |
84 | assert_column_conversion(types.UnicodeText(), graphene.String) | |
81 | assert get_field(types.UnicodeText()).type == graphene.String | |
85 | 82 | |
86 | 83 | |
87 | 84 | def test_should_enum_convert_enum(): |
88 | field = assert_column_conversion( | |
89 | types.Enum(enum.Enum("one", "two")), graphene.Field | |
90 | ) | |
85 | field = get_field(types.Enum(enum.Enum("TwoNumbers", ("one", "two")))) | |
91 | 86 | field_type = field.type() |
92 | 87 | assert isinstance(field_type, graphene.Enum) |
93 | assert hasattr(field_type, "two") | |
94 | field = assert_column_conversion( | |
95 | types.Enum("one", "two", name="two_numbers"), graphene.Field | |
96 | ) | |
88 | assert field_type._meta.name == "TwoNumbers" | |
89 | assert hasattr(field_type, "ONE") | |
90 | assert not hasattr(field_type, "one") | |
91 | assert hasattr(field_type, "TWO") | |
92 | assert not hasattr(field_type, "two") | |
93 | ||
94 | field = get_field(types.Enum("one", "two", name="two_numbers")) | |
97 | 95 | field_type = field.type() |
98 | assert field_type.__class__.__name__ == "two_numbers" | |
99 | 96 | assert isinstance(field_type, graphene.Enum) |
100 | assert hasattr(field_type, "two") | |
97 | assert field_type._meta.name == "TwoNumbers" | |
98 | assert hasattr(field_type, "ONE") | |
99 | assert not hasattr(field_type, "one") | |
100 | assert hasattr(field_type, "TWO") | |
101 | assert not hasattr(field_type, "two") | |
102 | ||
103 | ||
104 | def test_should_not_enum_convert_enum_without_name(): | |
105 | field = get_field(types.Enum("one", "two")) | |
106 | re_err = r"No type name specified for Enum\('one', 'two'\)" | |
107 | with pytest.raises(TypeError, match=re_err): | |
108 | field.type() | |
101 | 109 | |
102 | 110 | |
103 | 111 | def test_should_small_integer_convert_int(): |
104 | assert_column_conversion(types.SmallInteger(), graphene.Int) | |
112 | assert get_field(types.SmallInteger()).type == graphene.Int | |
105 | 113 | |
106 | 114 | |
107 | 115 | def test_should_big_integer_convert_int(): |
108 | assert_column_conversion(types.BigInteger(), graphene.Float) | |
116 | assert get_field(types.BigInteger()).type == graphene.Float | |
109 | 117 | |
110 | 118 | |
111 | 119 | def test_should_integer_convert_int(): |
112 | assert_column_conversion(types.Integer(), graphene.Int) | |
113 | ||
114 | ||
115 | def test_should_integer_convert_id(): | |
116 | assert_column_conversion(types.Integer(), graphene.ID, primary_key=True) | |
120 | assert get_field(types.Integer()).type == graphene.Int | |
121 | ||
122 | ||
123 | def test_should_primary_integer_convert_id(): | |
124 | assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull(graphene.ID) | |
117 | 125 | |
118 | 126 | |
119 | 127 | def test_should_boolean_convert_boolean(): |
120 | assert_column_conversion(types.Boolean(), graphene.Boolean) | |
128 | assert get_field(types.Boolean()).type == graphene.Boolean | |
121 | 129 | |
122 | 130 | |
123 | 131 | def test_should_float_convert_float(): |
124 | assert_column_conversion(types.Float(), graphene.Float) | |
132 | assert get_field(types.Float()).type == graphene.Float | |
125 | 133 | |
126 | 134 | |
127 | 135 | def test_should_numeric_convert_float(): |
128 | assert_column_conversion(types.Numeric(), graphene.Float) | |
129 | ||
130 | ||
131 | def test_should_label_convert_string(): | |
132 | label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) | |
133 | graphene_type = convert_sqlalchemy_column(label) | |
134 | assert isinstance(graphene_type, graphene.String) | |
135 | ||
136 | ||
137 | def test_should_label_convert_int(): | |
138 | label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) | |
139 | graphene_type = convert_sqlalchemy_column(label) | |
140 | assert isinstance(graphene_type, graphene.Int) | |
136 | assert get_field(types.Numeric()).type == graphene.Float | |
141 | 137 | |
142 | 138 | |
143 | 139 | def test_should_choice_convert_enum(): |
144 | TYPES = [(u"es", u"Spanish"), (u"en", u"English")] | |
145 | column = Column(ChoiceType(TYPES), doc="Language", name="language") | |
146 | Base = declarative_base() | |
147 | ||
148 | Table("translatedmodel", Base.metadata, column) | |
149 | graphene_type = convert_sqlalchemy_column(column) | |
140 | field = get_field(ChoiceType([(u"es", u"Spanish"), (u"en", u"English")])) | |
141 | graphene_type = field.type | |
150 | 142 | assert issubclass(graphene_type, graphene.Enum) |
151 | assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" | |
152 | assert graphene_type._meta.description == "Language" | |
143 | assert graphene_type._meta.name == "MODEL_COLUMN" | |
153 | 144 | assert graphene_type._meta.enum.__members__["es"].value == "Spanish" |
154 | 145 | assert graphene_type._meta.enum.__members__["en"].value == "English" |
155 | 146 | |
156 | 147 | |
148 | def test_should_enum_choice_convert_enum(): | |
149 | class TestEnum(enum.Enum): | |
150 | es = u"Spanish" | |
151 | en = u"English" | |
152 | ||
153 | field = get_field(ChoiceType(TestEnum, impl=types.String())) | |
154 | graphene_type = field.type | |
155 | assert issubclass(graphene_type, graphene.Enum) | |
156 | assert graphene_type._meta.name == "MODEL_COLUMN" | |
157 | assert graphene_type._meta.enum.__members__["es"].value == "Spanish" | |
158 | assert graphene_type._meta.enum.__members__["en"].value == "English" | |
159 | ||
160 | ||
161 | def test_should_intenum_choice_convert_enum(): | |
162 | class TestEnum(enum.IntEnum): | |
163 | one = 1 | |
164 | two = 2 | |
165 | ||
166 | field = get_field(ChoiceType(TestEnum, impl=types.String())) | |
167 | graphene_type = field.type | |
168 | assert issubclass(graphene_type, graphene.Enum) | |
169 | assert graphene_type._meta.name == "MODEL_COLUMN" | |
170 | assert graphene_type._meta.enum.__members__["one"].value == 1 | |
171 | assert graphene_type._meta.enum.__members__["two"].value == 2 | |
172 | ||
173 | ||
157 | 174 | def test_should_columproperty_convert(): |
158 | ||
159 | Base = declarative_base() | |
160 | ||
161 | class Test(Base): | |
162 | __tablename__ = "test" | |
163 | id = Column(types.Integer, primary_key=True) | |
164 | column = column_property( | |
165 | select([func.sum(func.cast(id, types.Integer))]).where(id == 1) | |
166 | ) | |
167 | ||
168 | graphene_type = convert_sqlalchemy_column(Test.column) | |
169 | assert not graphene_type.kwargs["required"] | |
175 | field = get_field_from_column(column_property( | |
176 | select([func.sum(func.cast(id, types.Integer))]).where(id == 1) | |
177 | )) | |
178 | ||
179 | assert field.type == graphene.Int | |
170 | 180 | |
171 | 181 | |
172 | 182 | def test_should_scalar_list_convert_list(): |
173 | assert_column_conversion(ScalarListType(), graphene.List) | |
183 | field = get_field(ScalarListType()) | |
184 | assert isinstance(field.type, graphene.List) | |
185 | assert field.type.of_type == graphene.String | |
174 | 186 | |
175 | 187 | |
176 | 188 | def test_should_jsontype_convert_jsonstring(): |
177 | assert_column_conversion(JSONType(), JSONString) | |
189 | assert get_field(JSONType()).type == JSONString | |
178 | 190 | |
179 | 191 | |
180 | 192 | def test_should_manytomany_convert_connectionorlist(): |
181 | registry = Registry() | |
182 | dynamic_field = convert_sqlalchemy_relationship( | |
183 | Reporter.pets.property, registry, default_connection_field_factory | |
193 | class A(SQLAlchemyObjectType): | |
194 | class Meta: | |
195 | model = Article | |
196 | ||
197 | dynamic_field = convert_sqlalchemy_relationship( | |
198 | Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', | |
184 | 199 | ) |
185 | 200 | assert isinstance(dynamic_field, graphene.Dynamic) |
186 | 201 | assert not dynamic_field.get_type() |
192 | 207 | model = Pet |
193 | 208 | |
194 | 209 | dynamic_field = convert_sqlalchemy_relationship( |
195 | Reporter.pets.property, A._meta.registry, default_connection_field_factory | |
210 | Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', | |
196 | 211 | ) |
197 | 212 | assert isinstance(dynamic_field, graphene.Dynamic) |
198 | 213 | graphene_type = dynamic_field.get_type() |
208 | 223 | interfaces = (Node,) |
209 | 224 | |
210 | 225 | dynamic_field = convert_sqlalchemy_relationship( |
211 | Reporter.pets.property, A._meta.registry, default_connection_field_factory | |
226 | Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', | |
212 | 227 | ) |
213 | 228 | assert isinstance(dynamic_field, graphene.Dynamic) |
214 | 229 | assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) |
215 | 230 | |
216 | 231 | |
217 | 232 | def test_should_manytoone_convert_connectionorlist(): |
218 | registry = Registry() | |
219 | dynamic_field = convert_sqlalchemy_relationship( | |
220 | Article.reporter.property, registry, default_connection_field_factory | |
233 | class A(SQLAlchemyObjectType): | |
234 | class Meta: | |
235 | model = Article | |
236 | ||
237 | dynamic_field = convert_sqlalchemy_relationship( | |
238 | Reporter.pets.property, A, default_connection_field_factory, True, 'orm_field_name', | |
221 | 239 | ) |
222 | 240 | assert isinstance(dynamic_field, graphene.Dynamic) |
223 | 241 | assert not dynamic_field.get_type() |
229 | 247 | model = Reporter |
230 | 248 | |
231 | 249 | dynamic_field = convert_sqlalchemy_relationship( |
232 | Article.reporter.property, A._meta.registry, default_connection_field_factory | |
250 | Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', | |
233 | 251 | ) |
234 | 252 | assert isinstance(dynamic_field, graphene.Dynamic) |
235 | 253 | graphene_type = dynamic_field.get_type() |
244 | 262 | interfaces = (Node,) |
245 | 263 | |
246 | 264 | dynamic_field = convert_sqlalchemy_relationship( |
247 | Article.reporter.property, A._meta.registry, default_connection_field_factory | |
265 | Article.reporter.property, A, default_connection_field_factory, True, 'orm_field_name', | |
248 | 266 | ) |
249 | 267 | assert isinstance(dynamic_field, graphene.Dynamic) |
250 | 268 | graphene_type = dynamic_field.get_type() |
259 | 277 | interfaces = (Node,) |
260 | 278 | |
261 | 279 | dynamic_field = convert_sqlalchemy_relationship( |
262 | Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory | |
280 | Reporter.favorite_article.property, A, default_connection_field_factory, True, 'orm_field_name', | |
263 | 281 | ) |
264 | 282 | assert isinstance(dynamic_field, graphene.Dynamic) |
265 | 283 | graphene_type = dynamic_field.get_type() |
268 | 286 | |
269 | 287 | |
270 | 288 | def test_should_postgresql_uuid_convert(): |
271 | assert_column_conversion(postgresql.UUID(), graphene.String) | |
289 | assert get_field(postgresql.UUID()).type == graphene.String | |
272 | 290 | |
273 | 291 | |
274 | 292 | def test_should_postgresql_enum_convert(): |
275 | field = assert_column_conversion( | |
276 | postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field | |
277 | ) | |
293 | field = get_field(postgresql.ENUM("one", "two", name="two_numbers")) | |
278 | 294 | field_type = field.type() |
279 | assert field_type.__class__.__name__ == "two_numbers" | |
280 | 295 | assert isinstance(field_type, graphene.Enum) |
281 | assert hasattr(field_type, "two") | |
296 | assert field_type._meta.name == "TwoNumbers" | |
297 | assert hasattr(field_type, "ONE") | |
298 | assert not hasattr(field_type, "one") | |
299 | assert hasattr(field_type, "TWO") | |
300 | assert not hasattr(field_type, "two") | |
282 | 301 | |
283 | 302 | |
284 | 303 | def test_should_postgresql_py_enum_convert(): |
285 | field = assert_column_conversion( | |
286 | postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field | |
287 | ) | |
304 | field = get_field(postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers")) | |
288 | 305 | field_type = field.type() |
289 | assert field_type.__class__.__name__ == "TwoNumbers" | |
306 | assert field_type._meta.name == "TwoNumbers" | |
290 | 307 | assert isinstance(field_type, graphene.Enum) |
291 | assert hasattr(field_type, "two") | |
308 | assert hasattr(field_type, "ONE") | |
309 | assert not hasattr(field_type, "one") | |
310 | assert hasattr(field_type, "TWO") | |
311 | assert not hasattr(field_type, "two") | |
292 | 312 | |
293 | 313 | |
294 | 314 | def test_should_postgresql_array_convert(): |
295 | assert_column_conversion(postgresql.ARRAY(types.Integer), graphene.List) | |
315 | field = get_field(postgresql.ARRAY(types.Integer)) | |
316 | assert isinstance(field.type, graphene.List) | |
317 | assert field.type.of_type == graphene.Int | |
318 | ||
319 | ||
320 | def test_should_array_convert(): | |
321 | field = get_field(types.ARRAY(types.Integer)) | |
322 | assert isinstance(field.type, graphene.List) | |
323 | assert field.type.of_type == graphene.Int | |
296 | 324 | |
297 | 325 | |
298 | 326 | def test_should_postgresql_json_convert(): |
299 | assert_column_conversion(postgresql.JSON(), JSONString) | |
327 | assert get_field(postgresql.JSON()).type == graphene.JSONString | |
300 | 328 | |
301 | 329 | |
302 | 330 | def test_should_postgresql_jsonb_convert(): |
303 | assert_column_conversion(postgresql.JSONB(), JSONString) | |
331 | assert get_field(postgresql.JSONB()).type == graphene.JSONString | |
304 | 332 | |
305 | 333 | |
306 | 334 | def test_should_postgresql_hstore_convert(): |
307 | assert_column_conversion(postgresql.HSTORE(), JSONString) | |
335 | assert get_field(postgresql.HSTORE()).type == graphene.JSONString | |
308 | 336 | |
309 | 337 | |
310 | 338 | def test_should_composite_convert(): |
311 | class CompositeClass(object): | |
339 | registry = Registry() | |
340 | ||
341 | class CompositeClass: | |
312 | 342 | def __init__(self, col1, col2): |
313 | 343 | self.col1 = col1 |
314 | 344 | self.col2 = col2 |
315 | 345 | |
316 | registry = Registry() | |
317 | ||
318 | 346 | @convert_sqlalchemy_composite.register(CompositeClass, registry) |
319 | 347 | def convert_composite_class(composite, registry): |
320 | 348 | return graphene.String(description=composite.doc) |
321 | 349 | |
322 | assert_composite_conversion( | |
323 | CompositeClass, | |
324 | (Column(types.Unicode(50)), Column(types.Unicode(50))), | |
325 | graphene.String, | |
350 | field = convert_sqlalchemy_composite( | |
351 | composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"), | |
326 | 352 | registry, |
327 | ) | |
353 | mock_resolver, | |
354 | ) | |
355 | assert isinstance(field, graphene.String) | |
328 | 356 | |
329 | 357 | |
330 | 358 | def test_should_unknown_sqlalchemy_composite_raise_exception(): |
331 | registry = Registry() | |
332 | ||
333 | with raises(Exception) as excinfo: | |
334 | ||
335 | class CompositeClass(object): | |
336 | def __init__(self, col1, col2): | |
337 | self.col1 = col1 | |
338 | self.col2 = col2 | |
339 | ||
340 | assert_composite_conversion( | |
341 | CompositeClass, | |
342 | (Column(types.Unicode(50)), Column(types.Unicode(50))), | |
343 | graphene.String, | |
344 | registry, | |
359 | class CompositeClass: | |
360 | def __init__(self, col1, col2): | |
361 | self.col1 = col1 | |
362 | self.col2 = col2 | |
363 | ||
364 | re_err = "Don't know how to convert the composite field" | |
365 | with pytest.raises(Exception, match=re_err): | |
366 | convert_sqlalchemy_composite( | |
367 | composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))), | |
368 | Registry(), | |
369 | mock_resolver, | |
345 | 370 | ) |
346 | ||
347 | assert "Don't know how to convert the composite field" in str(excinfo.value) |
0 | from enum import Enum as PyEnum | |
1 | ||
2 | import pytest | |
3 | from sqlalchemy.types import Enum as SQLAlchemyEnumType | |
4 | ||
5 | from graphene import Enum | |
6 | ||
7 | from ..enums import _convert_sa_to_graphene_enum, enum_for_field | |
8 | from ..types import SQLAlchemyObjectType | |
9 | from .models import HairKind, Pet | |
10 | ||
11 | ||
12 | def test_convert_sa_to_graphene_enum_bad_type(): | |
13 | re_err = "Expected sqlalchemy.types.Enum, but got: 'foo'" | |
14 | with pytest.raises(TypeError, match=re_err): | |
15 | _convert_sa_to_graphene_enum("foo") | |
16 | ||
17 | ||
18 | def test_convert_sa_to_graphene_enum_based_on_py_enum(): | |
19 | class Color(PyEnum): | |
20 | RED = 1 | |
21 | GREEN = 2 | |
22 | BLUE = 3 | |
23 | ||
24 | sa_enum = SQLAlchemyEnumType(Color) | |
25 | graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") | |
26 | assert isinstance(graphene_enum, type(Enum)) | |
27 | assert graphene_enum._meta.name == "Color" | |
28 | assert graphene_enum._meta.enum is Color | |
29 | ||
30 | ||
31 | def test_convert_sa_to_graphene_enum_based_on_py_enum_with_bad_names(): | |
32 | class Color(PyEnum): | |
33 | red = 1 | |
34 | green = 2 | |
35 | blue = 3 | |
36 | ||
37 | sa_enum = SQLAlchemyEnumType(Color) | |
38 | graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") | |
39 | assert isinstance(graphene_enum, type(Enum)) | |
40 | assert graphene_enum._meta.name == "Color" | |
41 | assert graphene_enum._meta.enum is not Color | |
42 | assert [ | |
43 | (key, value.value) | |
44 | for key, value in graphene_enum._meta.enum.__members__.items() | |
45 | ] == [("RED", 1), ("GREEN", 2), ("BLUE", 3)] | |
46 | ||
47 | ||
48 | def test_convert_sa_enum_to_graphene_enum_based_on_list_named(): | |
49 | sa_enum = SQLAlchemyEnumType("red", "green", "blue", name="color_values") | |
50 | graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") | |
51 | assert isinstance(graphene_enum, type(Enum)) | |
52 | assert graphene_enum._meta.name == "ColorValues" | |
53 | assert [ | |
54 | (key, value.value) | |
55 | for key, value in graphene_enum._meta.enum.__members__.items() | |
56 | ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] | |
57 | ||
58 | ||
59 | def test_convert_sa_enum_to_graphene_enum_based_on_list_unnamed(): | |
60 | sa_enum = SQLAlchemyEnumType("red", "green", "blue") | |
61 | graphene_enum = _convert_sa_to_graphene_enum(sa_enum, "FallbackName") | |
62 | assert isinstance(graphene_enum, type(Enum)) | |
63 | assert graphene_enum._meta.name == "FallbackName" | |
64 | assert [ | |
65 | (key, value.value) | |
66 | for key, value in graphene_enum._meta.enum.__members__.items() | |
67 | ] == [("RED", 'red'), ("GREEN", 'green'), ("BLUE", 'blue')] | |
68 | ||
69 | ||
70 | def test_convert_sa_enum_to_graphene_enum_based_on_list_without_name(): | |
71 | sa_enum = SQLAlchemyEnumType("red", "green", "blue") | |
72 | re_err = r"No type name specified for Enum\('red', 'green', 'blue'\)" | |
73 | with pytest.raises(TypeError, match=re_err): | |
74 | _convert_sa_to_graphene_enum(sa_enum) | |
75 | ||
76 | ||
77 | def test_enum_for_field(): | |
78 | class PetType(SQLAlchemyObjectType): | |
79 | class Meta: | |
80 | model = Pet | |
81 | ||
82 | enum = enum_for_field(PetType, 'pet_kind') | |
83 | assert isinstance(enum, type(Enum)) | |
84 | assert enum._meta.name == "PetKind" | |
85 | assert [ | |
86 | (key, value.value) | |
87 | for key, value in enum._meta.enum.__members__.items() | |
88 | ] == [("CAT", 'cat'), ("DOG", 'dog')] | |
89 | enum2 = enum_for_field(PetType, 'pet_kind') | |
90 | assert enum2 is enum | |
91 | enum2 = PetType.enum_for_field('pet_kind') | |
92 | assert enum2 is enum | |
93 | ||
94 | enum = enum_for_field(PetType, 'hair_kind') | |
95 | assert isinstance(enum, type(Enum)) | |
96 | assert enum._meta.name == "HairKind" | |
97 | assert enum._meta.enum is HairKind | |
98 | enum2 = PetType.enum_for_field('hair_kind') | |
99 | assert enum2 is enum | |
100 | ||
101 | re_err = r"Cannot get PetType\.other_kind" | |
102 | with pytest.raises(TypeError, match=re_err): | |
103 | enum_for_field(PetType, 'other_kind') | |
104 | with pytest.raises(TypeError, match=re_err): | |
105 | PetType.enum_for_field('other_kind') | |
106 | ||
107 | re_err = r"PetType\.name does not map to enum column" | |
108 | with pytest.raises(TypeError, match=re_err): | |
109 | enum_for_field(PetType, 'name') | |
110 | with pytest.raises(TypeError, match=re_err): | |
111 | PetType.enum_for_field('name') | |
112 | ||
113 | re_err = r"Expected a field name, but got: None" | |
114 | with pytest.raises(TypeError, match=re_err): | |
115 | enum_for_field(PetType, None) | |
116 | with pytest.raises(TypeError, match=re_err): | |
117 | PetType.enum_for_field(None) | |
118 | ||
119 | re_err = "Expected SQLAlchemyObjectType, but got: None" | |
120 | with pytest.raises(TypeError, match=re_err): | |
121 | enum_for_field(None, 'other_kind') |
0 | 0 | import pytest |
1 | from promise import Promise | |
1 | 2 | |
2 | from graphene.relay import Connection | |
3 | from graphene import NonNull, ObjectType | |
4 | from graphene.relay import Connection, Node | |
3 | 5 | |
4 | from ..fields import SQLAlchemyConnectionField | |
6 | from ..fields import (SQLAlchemyConnectionField, | |
7 | UnsortedSQLAlchemyConnectionField) | |
5 | 8 | from ..types import SQLAlchemyObjectType |
6 | from ..utils import sort_argument_for_model | |
7 | from .models import Editor | |
9 | from .models import Editor as EditorModel | |
8 | 10 | from .models import Pet as PetModel |
9 | 11 | |
10 | 12 | |
11 | 13 | class Pet(SQLAlchemyObjectType): |
12 | 14 | class Meta: |
13 | 15 | model = PetModel |
16 | interfaces = (Node,) | |
14 | 17 | |
15 | 18 | |
16 | class PetConn(Connection): | |
19 | class Editor(SQLAlchemyObjectType): | |
17 | 20 | class Meta: |
18 | node = Pet | |
21 | model = EditorModel | |
22 | ||
23 | ## | |
24 | # SQLAlchemyConnectionField | |
25 | ## | |
26 | ||
27 | ||
28 | def test_nonnull_sqlalachemy_connection(): | |
29 | field = SQLAlchemyConnectionField(NonNull(Pet.connection)) | |
30 | assert isinstance(field.type, NonNull) | |
31 | assert issubclass(field.type.of_type, Connection) | |
32 | assert field.type.of_type._meta.node is Pet | |
33 | ||
34 | ||
35 | def test_required_sqlalachemy_connection(): | |
36 | field = SQLAlchemyConnectionField(Pet.connection, required=True) | |
37 | assert isinstance(field.type, NonNull) | |
38 | assert issubclass(field.type.of_type, Connection) | |
39 | assert field.type.of_type._meta.node is Pet | |
40 | ||
41 | ||
42 | def test_promise_connection_resolver(): | |
43 | def resolver(_obj, _info): | |
44 | return Promise.resolve([]) | |
45 | ||
46 | result = UnsortedSQLAlchemyConnectionField.connection_resolver( | |
47 | resolver, Pet.connection, Pet, None, None | |
48 | ) | |
49 | assert isinstance(result, Promise) | |
50 | ||
51 | ||
52 | def test_type_assert_sqlalchemy_object_type(): | |
53 | with pytest.raises(AssertionError, match="only accepts SQLAlchemyObjectType"): | |
54 | SQLAlchemyConnectionField(ObjectType).type | |
55 | ||
56 | ||
57 | def test_type_assert_object_has_connection(): | |
58 | with pytest.raises(AssertionError, match="doesn't have a connection"): | |
59 | SQLAlchemyConnectionField(Editor).type | |
60 | ||
61 | ## | |
62 | # UnsortedSQLAlchemyConnectionField | |
63 | ## | |
19 | 64 | |
20 | 65 | |
21 | 66 | def test_sort_added_by_default(): |
22 | arg = SQLAlchemyConnectionField(PetConn) | |
23 | assert "sort" in arg.args | |
24 | assert arg.args["sort"] == sort_argument_for_model(PetModel) | |
67 | field = SQLAlchemyConnectionField(Pet.connection) | |
68 | assert "sort" in field.args | |
69 | assert field.args["sort"] == Pet.sort_argument() | |
25 | 70 | |
26 | 71 | |
27 | 72 | def test_sort_can_be_removed(): |
28 | arg = SQLAlchemyConnectionField(PetConn, sort=None) | |
29 | assert "sort" not in arg.args | |
73 | field = SQLAlchemyConnectionField(Pet.connection, sort=None) | |
74 | assert "sort" not in field.args | |
30 | 75 | |
31 | 76 | |
32 | 77 | def test_custom_sort(): |
33 | arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor)) | |
34 | assert arg.args["sort"] == sort_argument_for_model(Editor) | |
78 | field = SQLAlchemyConnectionField(Pet.connection, sort=Editor.sort_argument()) | |
79 | assert field.args["sort"] == Editor.sort_argument() | |
35 | 80 | |
36 | 81 | |
37 | def test_init_raises(): | |
38 | with pytest.raises(Exception, match="Cannot create sort"): | |
82 | def test_sort_init_raises(): | |
83 | with pytest.raises(TypeError, match="Cannot create sort"): | |
39 | 84 | SQLAlchemyConnectionField(Connection) |
0 | import pytest | |
1 | from sqlalchemy import create_engine | |
2 | from sqlalchemy.orm import scoped_session, sessionmaker | |
3 | ||
4 | 0 | import graphene |
5 | from graphene.relay import Connection, Node | |
6 | ||
1 | from graphene.relay import Node | |
2 | ||
3 | from ..converter import convert_sqlalchemy_composite | |
7 | 4 | from ..fields import SQLAlchemyConnectionField |
8 | from ..registry import reset_global_registry | |
9 | from ..types import SQLAlchemyObjectType | |
10 | from ..utils import sort_argument_for_model, sort_enum_for_model | |
11 | from .models import Article, Base, Editor, Hairkind, Pet, Reporter | |
12 | ||
13 | db = create_engine("sqlite:///test_sqlalchemy.sqlite3") | |
14 | ||
15 | ||
16 | @pytest.yield_fixture(scope="function") | |
17 | def session(): | |
18 | reset_global_registry() | |
19 | connection = db.engine.connect() | |
20 | transaction = connection.begin() | |
21 | Base.metadata.create_all(connection) | |
22 | ||
23 | # options = dict(bind=connection, binds={}) | |
24 | session_factory = sessionmaker(bind=connection) | |
25 | session = scoped_session(session_factory) | |
26 | ||
27 | yield session | |
28 | ||
29 | # Finalize test here | |
30 | transaction.rollback() | |
31 | connection.close() | |
32 | session.remove() | |
33 | ||
34 | ||
35 | def setup_fixtures(session): | |
36 | pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) | |
5 | from ..types import ORMField, SQLAlchemyObjectType | |
6 | from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter | |
7 | from .utils import to_std_dicts | |
8 | ||
9 | ||
10 | def add_test_data(session): | |
11 | reporter = Reporter( | |
12 | first_name='John', last_name='Doe', favorite_pet_kind='cat') | |
13 | session.add(reporter) | |
14 | pet = Pet(name='Garfield', pet_kind='cat', hair_kind=HairKind.SHORT) | |
37 | 15 | session.add(pet) |
38 | reporter = Reporter(first_name="ABA", last_name="X") | |
39 | session.add(reporter) | |
40 | reporter2 = Reporter(first_name="ABO", last_name="Y") | |
41 | session.add(reporter2) | |
42 | article = Article(headline="Hi!") | |
16 | pet.reporters.append(reporter) | |
17 | article = Article(headline='Hi!') | |
43 | 18 | article.reporter = reporter |
44 | 19 | session.add(article) |
45 | editor = Editor(name="John") | |
20 | reporter = Reporter( | |
21 | first_name='Jane', last_name='Roe', favorite_pet_kind='dog') | |
22 | session.add(reporter) | |
23 | pet = Pet(name='Lassie', pet_kind='dog', hair_kind=HairKind.LONG) | |
24 | pet.reporters.append(reporter) | |
25 | session.add(pet) | |
26 | editor = Editor(name="Jack") | |
46 | 27 | session.add(editor) |
47 | 28 | session.commit() |
48 | 29 | |
49 | 30 | |
50 | def test_should_query_well(session): | |
51 | setup_fixtures(session) | |
31 | def test_query_fields(session): | |
32 | add_test_data(session) | |
33 | ||
34 | @convert_sqlalchemy_composite.register(CompositeFullName) | |
35 | def convert_composite_class(composite, registry): | |
36 | return graphene.String() | |
52 | 37 | |
53 | 38 | class ReporterType(SQLAlchemyObjectType): |
54 | 39 | class Meta: |
58 | 43 | reporter = graphene.Field(ReporterType) |
59 | 44 | reporters = graphene.List(ReporterType) |
60 | 45 | |
61 | def resolve_reporter(self, *args, **kwargs): | |
46 | def resolve_reporter(self, _info): | |
62 | 47 | return session.query(Reporter).first() |
63 | 48 | |
64 | def resolve_reporters(self, *args, **kwargs): | |
49 | def resolve_reporters(self, _info): | |
65 | 50 | return session.query(Reporter) |
66 | 51 | |
67 | 52 | query = """ |
68 | query ReporterQuery { | |
53 | query { | |
69 | 54 | reporter { |
70 | firstName, | |
71 | lastName, | |
72 | ||
55 | firstName | |
56 | columnProp | |
57 | hybridProp | |
58 | compositeProp | |
73 | 59 | } |
74 | 60 | reporters { |
75 | 61 | firstName |
77 | 63 | } |
78 | 64 | """ |
79 | 65 | expected = { |
80 | "reporter": {"firstName": "ABA", "lastName": "X", "email": None}, | |
81 | "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}], | |
66 | "reporter": { | |
67 | "firstName": "John", | |
68 | "hybridProp": "John", | |
69 | "columnProp": 2, | |
70 | "compositeProp": "John Doe", | |
71 | }, | |
72 | "reporters": [{"firstName": "John"}, {"firstName": "Jane"}], | |
82 | 73 | } |
83 | 74 | schema = graphene.Schema(query=Query) |
84 | 75 | result = schema.execute(query) |
85 | 76 | assert not result.errors |
86 | assert result.data == expected | |
87 | ||
88 | ||
89 | def test_should_query_enums(session): | |
90 | setup_fixtures(session) | |
91 | ||
92 | class PetType(SQLAlchemyObjectType): | |
93 | class Meta: | |
94 | model = Pet | |
95 | ||
96 | class Query(graphene.ObjectType): | |
97 | pet = graphene.Field(PetType) | |
98 | ||
99 | def resolve_pet(self, *args, **kwargs): | |
100 | return session.query(Pet).first() | |
101 | ||
102 | query = """ | |
103 | query PetQuery { | |
104 | pet { | |
105 | name, | |
106 | petKind | |
107 | hairKind | |
108 | } | |
109 | } | |
110 | """ | |
111 | expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} | |
112 | schema = graphene.Schema(query=Query) | |
113 | result = schema.execute(query) | |
114 | assert not result.errors | |
115 | assert result.data == expected, result.data | |
116 | ||
117 | ||
118 | def test_enum_parameter(session): | |
119 | setup_fixtures(session) | |
120 | ||
121 | class PetType(SQLAlchemyObjectType): | |
122 | class Meta: | |
123 | model = Pet | |
124 | ||
125 | class Query(graphene.ObjectType): | |
126 | pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) | |
127 | ||
128 | def resolve_pet(self, info, kind=None, *args, **kwargs): | |
129 | query = session.query(Pet) | |
130 | if kind: | |
131 | query = query.filter(Pet.pet_kind == kind) | |
132 | return query.first() | |
133 | ||
134 | query = """ | |
135 | query PetQuery($kind: pet_kind) { | |
136 | pet(kind: $kind) { | |
137 | name, | |
138 | petKind | |
139 | hairKind | |
140 | } | |
141 | } | |
142 | """ | |
143 | expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} | |
144 | schema = graphene.Schema(query=Query) | |
145 | result = schema.execute(query, variables={"kind": "cat"}) | |
146 | assert not result.errors | |
147 | assert result.data == {"pet": None} | |
148 | result = schema.execute(query, variables={"kind": "dog"}) | |
149 | assert not result.errors | |
150 | assert result.data == expected, result.data | |
151 | ||
152 | ||
153 | def test_py_enum_parameter(session): | |
154 | setup_fixtures(session) | |
155 | ||
156 | class PetType(SQLAlchemyObjectType): | |
157 | class Meta: | |
158 | model = Pet | |
159 | ||
160 | class Query(graphene.ObjectType): | |
161 | pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) | |
162 | ||
163 | def resolve_pet(self, info, kind=None, *args, **kwargs): | |
164 | query = session.query(Pet) | |
165 | if kind: | |
166 | # XXX Why kind passed in as a str instead of a Hairkind instance? | |
167 | query = query.filter(Pet.hair_kind == Hairkind(kind)) | |
168 | return query.first() | |
169 | ||
170 | query = """ | |
171 | query PetQuery($kind: Hairkind) { | |
172 | pet(kind: $kind) { | |
173 | name, | |
174 | petKind | |
175 | hairKind | |
176 | } | |
177 | } | |
178 | """ | |
179 | expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} | |
180 | schema = graphene.Schema(query=Query) | |
181 | result = schema.execute(query, variables={"kind": "SHORT"}) | |
182 | assert not result.errors | |
183 | assert result.data == {"pet": None} | |
184 | result = schema.execute(query, variables={"kind": "LONG"}) | |
185 | assert not result.errors | |
186 | assert result.data == expected, result.data | |
187 | ||
188 | ||
189 | def test_should_node(session): | |
190 | setup_fixtures(session) | |
77 | result = to_std_dicts(result.data) | |
78 | assert result == expected | |
79 | ||
80 | ||
81 | def test_query_node(session): | |
82 | add_test_data(session) | |
191 | 83 | |
192 | 84 | class ReporterNode(SQLAlchemyObjectType): |
193 | 85 | class Meta: |
203 | 95 | model = Article |
204 | 96 | interfaces = (Node,) |
205 | 97 | |
206 | # @classmethod | |
207 | # def get_node(cls, id, info): | |
208 | # return Article(id=1, headline='Article node') | |
209 | ||
210 | class ArticleConnection(Connection): | |
211 | class Meta: | |
212 | node = ArticleNode | |
213 | ||
214 | 98 | class Query(graphene.ObjectType): |
215 | 99 | node = Node.Field() |
216 | 100 | reporter = graphene.Field(ReporterNode) |
217 | article = graphene.Field(ArticleNode) | |
218 | all_articles = SQLAlchemyConnectionField(ArticleConnection) | |
219 | ||
220 | def resolve_reporter(self, *args, **kwargs): | |
101 | all_articles = SQLAlchemyConnectionField(ArticleNode.connection) | |
102 | ||
103 | def resolve_reporter(self, _info): | |
221 | 104 | return session.query(Reporter).first() |
222 | 105 | |
223 | def resolve_article(self, *args, **kwargs): | |
224 | return session.query(Article).first() | |
225 | ||
226 | query = """ | |
227 | query ReporterQuery { | |
106 | query = """ | |
107 | query { | |
228 | 108 | reporter { |
229 | id, | |
230 | firstName, | |
109 | id | |
110 | firstName | |
231 | 111 | articles { |
232 | 112 | edges { |
233 | 113 | node { |
235 | 115 | } |
236 | 116 | } |
237 | 117 | } |
238 | lastName, | |
239 | ||
240 | 118 | } |
241 | 119 | allArticles { |
242 | 120 | edges { |
259 | 137 | expected = { |
260 | 138 | "reporter": { |
261 | 139 | "id": "UmVwb3J0ZXJOb2RlOjE=", |
262 | "firstName": "ABA", | |
263 | "lastName": "X", | |
264 | "email": None, | |
140 | "firstName": "John", | |
265 | 141 | "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, |
266 | 142 | }, |
267 | 143 | "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, |
270 | 146 | schema = graphene.Schema(query=Query) |
271 | 147 | result = schema.execute(query, context_value={"session": session}) |
272 | 148 | assert not result.errors |
273 | assert result.data == expected | |
274 | ||
275 | ||
276 | def test_should_custom_identifier(session): | |
277 | setup_fixtures(session) | |
149 | result = to_std_dicts(result.data) | |
150 | assert result == expected | |
151 | ||
152 | ||
153 | def test_orm_field(session): | |
154 | add_test_data(session) | |
155 | ||
156 | @convert_sqlalchemy_composite.register(CompositeFullName) | |
157 | def convert_composite_class(composite, registry): | |
158 | return graphene.String() | |
159 | ||
160 | class ReporterType(SQLAlchemyObjectType): | |
161 | class Meta: | |
162 | model = Reporter | |
163 | interfaces = (Node,) | |
164 | ||
165 | first_name_v2 = ORMField(model_attr='first_name') | |
166 | hybrid_prop_v2 = ORMField(model_attr='hybrid_prop') | |
167 | column_prop_v2 = ORMField(model_attr='column_prop') | |
168 | composite_prop = ORMField() | |
169 | favorite_article_v2 = ORMField(model_attr='favorite_article') | |
170 | articles_v2 = ORMField(model_attr='articles') | |
171 | ||
172 | class ArticleType(SQLAlchemyObjectType): | |
173 | class Meta: | |
174 | model = Article | |
175 | interfaces = (Node,) | |
176 | ||
177 | class Query(graphene.ObjectType): | |
178 | reporter = graphene.Field(ReporterType) | |
179 | ||
180 | def resolve_reporter(self, _info): | |
181 | return session.query(Reporter).first() | |
182 | ||
183 | query = """ | |
184 | query { | |
185 | reporter { | |
186 | firstNameV2 | |
187 | hybridPropV2 | |
188 | columnPropV2 | |
189 | compositeProp | |
190 | favoriteArticleV2 { | |
191 | headline | |
192 | } | |
193 | articlesV2(first: 1) { | |
194 | edges { | |
195 | node { | |
196 | headline | |
197 | } | |
198 | } | |
199 | } | |
200 | } | |
201 | } | |
202 | """ | |
203 | expected = { | |
204 | "reporter": { | |
205 | "firstNameV2": "John", | |
206 | "hybridPropV2": "John", | |
207 | "columnPropV2": 2, | |
208 | "compositeProp": "John Doe", | |
209 | "favoriteArticleV2": {"headline": "Hi!"}, | |
210 | "articlesV2": {"edges": [{"node": {"headline": "Hi!"}}]}, | |
211 | }, | |
212 | } | |
213 | schema = graphene.Schema(query=Query) | |
214 | result = schema.execute(query, context_value={"session": session}) | |
215 | assert not result.errors | |
216 | result = to_std_dicts(result.data) | |
217 | assert result == expected | |
218 | ||
219 | ||
220 | def test_custom_identifier(session): | |
221 | add_test_data(session) | |
278 | 222 | |
279 | 223 | class EditorNode(SQLAlchemyObjectType): |
280 | 224 | class Meta: |
281 | 225 | model = Editor |
282 | 226 | interfaces = (Node,) |
283 | 227 | |
284 | class EditorConnection(Connection): | |
285 | class Meta: | |
286 | node = EditorNode | |
287 | ||
288 | 228 | class Query(graphene.ObjectType): |
289 | 229 | node = Node.Field() |
290 | all_editors = SQLAlchemyConnectionField(EditorConnection) | |
291 | ||
292 | query = """ | |
293 | query EditorQuery { | |
230 | all_editors = SQLAlchemyConnectionField(EditorNode.connection) | |
231 | ||
232 | query = """ | |
233 | query { | |
294 | 234 | allEditors { |
295 | 235 | edges { |
296 | 236 | node { |
297 | id, | |
237 | id | |
298 | 238 | name |
299 | 239 | } |
300 | 240 | } |
307 | 247 | } |
308 | 248 | """ |
309 | 249 | expected = { |
310 | "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]}, | |
311 | "node": {"name": "John"}, | |
250 | "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "Jack"}}]}, | |
251 | "node": {"name": "Jack"}, | |
312 | 252 | } |
313 | 253 | |
314 | 254 | schema = graphene.Schema(query=Query) |
315 | 255 | result = schema.execute(query, context_value={"session": session}) |
316 | 256 | assert not result.errors |
317 | assert result.data == expected | |
318 | ||
319 | ||
320 | def test_should_mutate_well(session): | |
321 | setup_fixtures(session) | |
257 | result = to_std_dicts(result.data) | |
258 | assert result == expected | |
259 | ||
260 | ||
261 | def test_mutation(session): | |
262 | add_test_data(session) | |
322 | 263 | |
323 | 264 | class EditorNode(SQLAlchemyObjectType): |
324 | 265 | class Meta: |
363 | 304 | create_article = CreateArticle.Field() |
364 | 305 | |
365 | 306 | query = """ |
366 | mutation ArticleCreator { | |
307 | mutation { | |
367 | 308 | createArticle( |
368 | 309 | headline: "My Article" |
369 | 310 | reporterId: "1" |
384 | 325 | "ok": True, |
385 | 326 | "article": { |
386 | 327 | "headline": "My Article", |
387 | "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"}, | |
328 | "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "John"}, | |
388 | 329 | }, |
389 | 330 | } |
390 | 331 | } |
392 | 333 | schema = graphene.Schema(query=Query, mutation=Mutation) |
393 | 334 | result = schema.execute(query, context_value={"session": session}) |
394 | 335 | assert not result.errors |
395 | assert result.data == expected | |
396 | ||
397 | ||
398 | def sort_setup(session): | |
399 | pets = [ | |
400 | Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), | |
401 | Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), | |
402 | Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), | |
403 | ] | |
404 | session.add_all(pets) | |
405 | session.commit() | |
406 | ||
407 | ||
408 | def test_sort(session): | |
409 | sort_setup(session) | |
410 | ||
411 | class PetNode(SQLAlchemyObjectType): | |
412 | class Meta: | |
413 | model = Pet | |
414 | interfaces = (Node,) | |
415 | ||
416 | class PetConnection(Connection): | |
417 | class Meta: | |
418 | node = PetNode | |
419 | ||
420 | class Query(graphene.ObjectType): | |
421 | defaultSort = SQLAlchemyConnectionField(PetConnection) | |
422 | nameSort = SQLAlchemyConnectionField(PetConnection) | |
423 | multipleSort = SQLAlchemyConnectionField(PetConnection) | |
424 | descSort = SQLAlchemyConnectionField(PetConnection) | |
425 | singleColumnSort = SQLAlchemyConnectionField( | |
426 | PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet)) | |
427 | ) | |
428 | noDefaultSort = SQLAlchemyConnectionField( | |
429 | PetConnection, sort=sort_argument_for_model(Pet, False) | |
430 | ) | |
431 | noSort = SQLAlchemyConnectionField(PetConnection, sort=None) | |
432 | ||
433 | query = """ | |
434 | query sortTest { | |
435 | defaultSort{ | |
436 | edges{ | |
437 | node{ | |
438 | id | |
439 | } | |
440 | } | |
441 | } | |
442 | nameSort(sort: name_asc){ | |
443 | edges{ | |
444 | node{ | |
445 | name | |
446 | } | |
447 | } | |
448 | } | |
449 | multipleSort(sort: [pet_kind_asc, name_desc]){ | |
450 | edges{ | |
451 | node{ | |
452 | name | |
453 | petKind | |
454 | } | |
455 | } | |
456 | } | |
457 | descSort(sort: [name_desc]){ | |
458 | edges{ | |
459 | node{ | |
460 | name | |
461 | } | |
462 | } | |
463 | } | |
464 | singleColumnSort(sort: name_desc){ | |
465 | edges{ | |
466 | node{ | |
467 | name | |
468 | } | |
469 | } | |
470 | } | |
471 | noDefaultSort(sort: name_asc){ | |
472 | edges{ | |
473 | node{ | |
474 | name | |
475 | } | |
476 | } | |
477 | } | |
478 | } | |
479 | """ | |
480 | ||
481 | def makeNodes(nodeList): | |
482 | nodes = [{"node": item} for item in nodeList] | |
483 | return {"edges": nodes} | |
484 | ||
485 | expected = { | |
486 | "defaultSort": makeNodes( | |
487 | [{"id": "UGV0Tm9kZToy"}, {"id": "UGV0Tm9kZToz"}, {"id": "UGV0Tm9kZToyMg=="}] | |
488 | ), | |
489 | "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), | |
490 | "noDefaultSort": makeNodes( | |
491 | [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] | |
492 | ), | |
493 | "multipleSort": makeNodes( | |
494 | [ | |
495 | {"name": "Alf", "petKind": "cat"}, | |
496 | {"name": "Lassie", "petKind": "dog"}, | |
497 | {"name": "Barf", "petKind": "dog"}, | |
498 | ] | |
499 | ), | |
500 | "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), | |
501 | "singleColumnSort": makeNodes( | |
502 | [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] | |
503 | ), | |
504 | } # yapf: disable | |
505 | ||
506 | schema = graphene.Schema(query=Query) | |
507 | result = schema.execute(query, context_value={"session": session}) | |
508 | assert not result.errors | |
509 | assert result.data == expected | |
510 | ||
511 | queryError = """ | |
512 | query sortTest { | |
513 | singleColumnSort(sort: [pet_kind_asc, name_desc]){ | |
514 | edges{ | |
515 | node{ | |
516 | name | |
517 | } | |
518 | } | |
519 | } | |
520 | } | |
521 | """ | |
522 | result = schema.execute(queryError, context_value={"session": session}) | |
523 | assert result.errors is not None | |
524 | ||
525 | queryNoSort = """ | |
526 | query sortTest { | |
527 | noDefaultSort{ | |
528 | edges{ | |
529 | node{ | |
530 | name | |
531 | } | |
532 | } | |
533 | } | |
534 | noSort{ | |
535 | edges{ | |
536 | node{ | |
537 | name | |
538 | } | |
539 | } | |
540 | } | |
541 | } | |
542 | """ | |
543 | ||
544 | expectedNoSort = { | |
545 | "noDefaultSort": makeNodes( | |
546 | [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] | |
547 | ), | |
548 | "noSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), | |
549 | } # yapf: disable | |
550 | ||
551 | result = schema.execute(queryNoSort, context_value={"session": session}) | |
552 | assert not result.errors | |
553 | for key, value in result.data.items(): | |
554 | assert set(node["node"]["name"] for node in value["edges"]) == set( | |
555 | node["node"]["name"] for node in expectedNoSort[key]["edges"] | |
556 | ) | |
336 | result = to_std_dicts(result.data) | |
337 | assert result == expected |
0 | import graphene | |
1 | ||
2 | from ..types import SQLAlchemyObjectType | |
3 | from .models import HairKind, Pet, Reporter | |
4 | from .test_query import add_test_data, to_std_dicts | |
5 | ||
6 | ||
7 | def test_query_pet_kinds(session): | |
8 | add_test_data(session) | |
9 | ||
10 | class PetType(SQLAlchemyObjectType): | |
11 | ||
12 | class Meta: | |
13 | model = Pet | |
14 | ||
15 | class ReporterType(SQLAlchemyObjectType): | |
16 | class Meta: | |
17 | model = Reporter | |
18 | ||
19 | class Query(graphene.ObjectType): | |
20 | reporter = graphene.Field(ReporterType) | |
21 | reporters = graphene.List(ReporterType) | |
22 | pets = graphene.List(PetType, kind=graphene.Argument( | |
23 | PetType.enum_for_field('pet_kind'))) | |
24 | ||
25 | def resolve_reporter(self, _info): | |
26 | return session.query(Reporter).first() | |
27 | ||
28 | def resolve_reporters(self, _info): | |
29 | return session.query(Reporter) | |
30 | ||
31 | def resolve_pets(self, _info, kind): | |
32 | query = session.query(Pet) | |
33 | if kind: | |
34 | query = query.filter_by(pet_kind=kind) | |
35 | return query | |
36 | ||
37 | query = """ | |
38 | query ReporterQuery { | |
39 | reporter { | |
40 | firstName | |
41 | lastName | |
42 | ||
43 | favoritePetKind | |
44 | pets { | |
45 | name | |
46 | petKind | |
47 | } | |
48 | } | |
49 | reporters { | |
50 | firstName | |
51 | favoritePetKind | |
52 | } | |
53 | pets(kind: DOG) { | |
54 | name | |
55 | petKind | |
56 | } | |
57 | } | |
58 | """ | |
59 | expected = { | |
60 | 'reporter': { | |
61 | 'firstName': 'John', | |
62 | 'lastName': 'Doe', | |
63 | 'email': None, | |
64 | 'favoritePetKind': 'CAT', | |
65 | 'pets': [{ | |
66 | 'name': 'Garfield', | |
67 | 'petKind': 'CAT' | |
68 | }] | |
69 | }, | |
70 | 'reporters': [{ | |
71 | 'firstName': 'John', | |
72 | 'favoritePetKind': 'CAT', | |
73 | }, { | |
74 | 'firstName': 'Jane', | |
75 | 'favoritePetKind': 'DOG', | |
76 | }], | |
77 | 'pets': [{ | |
78 | 'name': 'Lassie', | |
79 | 'petKind': 'DOG' | |
80 | }] | |
81 | } | |
82 | schema = graphene.Schema(query=Query) | |
83 | result = schema.execute(query) | |
84 | assert not result.errors | |
85 | assert result.data == expected | |
86 | ||
87 | ||
88 | def test_query_more_enums(session): | |
89 | add_test_data(session) | |
90 | ||
91 | class PetType(SQLAlchemyObjectType): | |
92 | class Meta: | |
93 | model = Pet | |
94 | ||
95 | class Query(graphene.ObjectType): | |
96 | pet = graphene.Field(PetType) | |
97 | ||
98 | def resolve_pet(self, _info): | |
99 | return session.query(Pet).first() | |
100 | ||
101 | query = """ | |
102 | query PetQuery { | |
103 | pet { | |
104 | name, | |
105 | petKind | |
106 | hairKind | |
107 | } | |
108 | } | |
109 | """ | |
110 | expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} | |
111 | schema = graphene.Schema(query=Query) | |
112 | result = schema.execute(query) | |
113 | assert not result.errors | |
114 | result = to_std_dicts(result.data) | |
115 | assert result == expected | |
116 | ||
117 | ||
118 | def test_enum_as_argument(session): | |
119 | add_test_data(session) | |
120 | ||
121 | class PetType(SQLAlchemyObjectType): | |
122 | class Meta: | |
123 | model = Pet | |
124 | ||
125 | class Query(graphene.ObjectType): | |
126 | pet = graphene.Field( | |
127 | PetType, | |
128 | kind=graphene.Argument(PetType.enum_for_field('pet_kind'))) | |
129 | ||
130 | def resolve_pet(self, info, kind=None): | |
131 | query = session.query(Pet) | |
132 | if kind: | |
133 | query = query.filter(Pet.pet_kind == kind) | |
134 | return query.first() | |
135 | ||
136 | query = """ | |
137 | query PetQuery($kind: PetKind) { | |
138 | pet(kind: $kind) { | |
139 | name, | |
140 | petKind | |
141 | hairKind | |
142 | } | |
143 | } | |
144 | """ | |
145 | ||
146 | schema = graphene.Schema(query=Query) | |
147 | result = schema.execute(query, variables={"kind": "CAT"}) | |
148 | assert not result.errors | |
149 | expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} | |
150 | assert result.data == expected | |
151 | result = schema.execute(query, variables={"kind": "DOG"}) | |
152 | assert not result.errors | |
153 | expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} | |
154 | result = to_std_dicts(result.data) | |
155 | assert result == expected | |
156 | ||
157 | ||
158 | def test_py_enum_as_argument(session): | |
159 | add_test_data(session) | |
160 | ||
161 | class PetType(SQLAlchemyObjectType): | |
162 | class Meta: | |
163 | model = Pet | |
164 | ||
165 | class Query(graphene.ObjectType): | |
166 | pet = graphene.Field( | |
167 | PetType, | |
168 | kind=graphene.Argument(PetType._meta.fields["hair_kind"].type.of_type), | |
169 | ) | |
170 | ||
171 | def resolve_pet(self, _info, kind=None): | |
172 | query = session.query(Pet) | |
173 | if kind: | |
174 | # enum arguments are expected to be strings, not PyEnums | |
175 | query = query.filter(Pet.hair_kind == HairKind(kind)) | |
176 | return query.first() | |
177 | ||
178 | query = """ | |
179 | query PetQuery($kind: HairKind) { | |
180 | pet(kind: $kind) { | |
181 | name, | |
182 | petKind | |
183 | hairKind | |
184 | } | |
185 | } | |
186 | """ | |
187 | ||
188 | schema = graphene.Schema(query=Query) | |
189 | result = schema.execute(query, variables={"kind": "SHORT"}) | |
190 | assert not result.errors | |
191 | expected = {"pet": {"name": "Garfield", "petKind": "CAT", "hairKind": "SHORT"}} | |
192 | assert result.data == expected | |
193 | result = schema.execute(query, variables={"kind": "LONG"}) | |
194 | assert not result.errors | |
195 | expected = {"pet": {"name": "Lassie", "petKind": "DOG", "hairKind": "LONG"}} | |
196 | result = to_std_dicts(result.data) | |
197 | assert result == expected |
0 | 0 | import pytest |
1 | from sqlalchemy.types import Enum as SQLAlchemyEnum | |
2 | ||
3 | from graphene import Enum as GrapheneEnum | |
1 | 4 | |
2 | 5 | from ..registry import Registry |
3 | 6 | from ..types import SQLAlchemyObjectType |
7 | from ..utils import EnumValue | |
4 | 8 | from .models import Pet |
5 | 9 | |
6 | 10 | |
7 | def test_register_incorrect_objecttype(): | |
8 | reg = Registry() | |
9 | ||
10 | class Spam: | |
11 | pass | |
12 | ||
13 | with pytest.raises(AssertionError) as excinfo: | |
14 | reg.register(Spam) | |
15 | ||
16 | assert "Only classes of type SQLAlchemyObjectType can be registered" in str( | |
17 | excinfo.value | |
18 | ) | |
19 | ||
20 | ||
21 | def test_register_objecttype(): | |
11 | def test_register_object_type(): | |
22 | 12 | reg = Registry() |
23 | 13 | |
24 | 14 | class PetType(SQLAlchemyObjectType): |
26 | 16 | model = Pet |
27 | 17 | registry = reg |
28 | 18 | |
29 | try: | |
30 | reg.register(PetType) | |
31 | except AssertionError: | |
32 | pytest.fail("expected no AssertionError") | |
19 | reg.register(PetType) | |
20 | assert reg.get_type_for_model(Pet) is PetType | |
21 | ||
22 | ||
23 | def test_register_incorrect_object_type(): | |
24 | reg = Registry() | |
25 | ||
26 | class Spam: | |
27 | pass | |
28 | ||
29 | re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" | |
30 | with pytest.raises(TypeError, match=re_err): | |
31 | reg.register(Spam) | |
32 | ||
33 | ||
34 | def test_register_orm_field(): | |
35 | reg = Registry() | |
36 | ||
37 | class PetType(SQLAlchemyObjectType): | |
38 | class Meta: | |
39 | model = Pet | |
40 | registry = reg | |
41 | ||
42 | reg.register_orm_field(PetType, "name", Pet.name) | |
43 | assert reg.get_orm_field_for_graphene_field(PetType, "name") is Pet.name | |
44 | ||
45 | ||
46 | def test_register_orm_field_incorrect_types(): | |
47 | reg = Registry() | |
48 | ||
49 | class Spam: | |
50 | pass | |
51 | ||
52 | re_err = "Expected SQLAlchemyObjectType, but got: .*Spam" | |
53 | with pytest.raises(TypeError, match=re_err): | |
54 | reg.register_orm_field(Spam, "name", Pet.name) | |
55 | ||
56 | class PetType(SQLAlchemyObjectType): | |
57 | class Meta: | |
58 | model = Pet | |
59 | registry = reg | |
60 | ||
61 | re_err = "Expected a field name, but got: .*Spam" | |
62 | with pytest.raises(TypeError, match=re_err): | |
63 | reg.register_orm_field(PetType, Spam, Pet.name) | |
64 | ||
65 | ||
66 | def test_register_enum(): | |
67 | reg = Registry() | |
68 | ||
69 | sa_enum = SQLAlchemyEnum("cat", "dog") | |
70 | graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) | |
71 | ||
72 | reg.register_enum(sa_enum, graphene_enum) | |
73 | assert reg.get_graphene_enum_for_sa_enum(sa_enum) is graphene_enum | |
74 | ||
75 | ||
76 | def test_register_enum_incorrect_types(): | |
77 | reg = Registry() | |
78 | ||
79 | sa_enum = SQLAlchemyEnum("cat", "dog") | |
80 | graphene_enum = GrapheneEnum("PetKind", [("CAT", 1), ("DOG", 2)]) | |
81 | ||
82 | re_err = r"Expected Graphene Enum, but got: Enum\('cat', 'dog'\)" | |
83 | with pytest.raises(TypeError, match=re_err): | |
84 | reg.register_enum(sa_enum, sa_enum) | |
85 | ||
86 | re_err = r"Expected SQLAlchemyEnumType, but got: .*PetKind.*" | |
87 | with pytest.raises(TypeError, match=re_err): | |
88 | reg.register_enum(graphene_enum, graphene_enum) | |
89 | ||
90 | ||
91 | def test_register_sort_enum(): | |
92 | reg = Registry() | |
93 | ||
94 | class PetType(SQLAlchemyObjectType): | |
95 | class Meta: | |
96 | model = Pet | |
97 | registry = reg | |
98 | ||
99 | sort_enum = GrapheneEnum( | |
100 | "PetSort", | |
101 | [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], | |
102 | ) | |
103 | ||
104 | reg.register_sort_enum(PetType, sort_enum) | |
105 | assert reg.get_sort_enum_for_object_type(PetType) is sort_enum | |
106 | ||
107 | ||
108 | def test_register_sort_enum_incorrect_types(): | |
109 | reg = Registry() | |
110 | ||
111 | class PetType(SQLAlchemyObjectType): | |
112 | class Meta: | |
113 | model = Pet | |
114 | registry = reg | |
115 | ||
116 | sort_enum = GrapheneEnum( | |
117 | "PetSort", | |
118 | [("ID", EnumValue("id", Pet.id)), ("NAME", EnumValue("name", Pet.name))], | |
119 | ) | |
120 | ||
121 | re_err = r"Expected SQLAlchemyObjectType, but got: .*PetSort.*" | |
122 | with pytest.raises(TypeError, match=re_err): | |
123 | reg.register_sort_enum(sort_enum, sort_enum) | |
124 | ||
125 | re_err = r"Expected Graphene Enum, but got: .*PetType.*" | |
126 | with pytest.raises(TypeError, match=re_err): | |
127 | reg.register_sort_enum(PetType, PetType) |
0 | from py.test import raises | |
1 | ||
2 | from ..registry import Registry | |
3 | from ..types import SQLAlchemyObjectType | |
4 | from .models import Reporter | |
5 | ||
6 | ||
7 | def test_should_raise_if_no_model(): | |
8 | with raises(Exception) as excinfo: | |
9 | ||
10 | class Character1(SQLAlchemyObjectType): | |
11 | pass | |
12 | ||
13 | assert "valid SQLAlchemy Model" in str(excinfo.value) | |
14 | ||
15 | ||
16 | def test_should_raise_if_model_is_invalid(): | |
17 | with raises(Exception) as excinfo: | |
18 | ||
19 | class Character2(SQLAlchemyObjectType): | |
20 | class Meta: | |
21 | model = 1 | |
22 | ||
23 | assert "valid SQLAlchemy Model" in str(excinfo.value) | |
24 | ||
25 | ||
26 | def test_should_map_fields_correctly(): | |
27 | class ReporterType2(SQLAlchemyObjectType): | |
28 | class Meta: | |
29 | model = Reporter | |
30 | registry = Registry() | |
31 | ||
32 | assert list(ReporterType2._meta.fields.keys()) == [ | |
33 | "id", | |
34 | "first_name", | |
35 | "last_name", | |
36 | "email", | |
37 | "pets", | |
38 | "articles", | |
39 | "favorite_article", | |
40 | ] | |
41 | ||
42 | ||
43 | def test_should_map_only_few_fields(): | |
44 | class Reporter2(SQLAlchemyObjectType): | |
45 | class Meta: | |
46 | model = Reporter | |
47 | only_fields = ("id", "email") | |
48 | assert list(Reporter2._meta.fields.keys()) == ["id", "email"] |
0 | import pytest | |
1 | import sqlalchemy as sa | |
2 | ||
3 | from graphene import Argument, Enum, List, ObjectType, Schema | |
4 | from graphene.relay import Node | |
5 | ||
6 | from ..fields import SQLAlchemyConnectionField | |
7 | from ..types import SQLAlchemyObjectType | |
8 | from ..utils import to_type_name | |
9 | from .models import Base, HairKind, Pet | |
10 | from .test_query import to_std_dicts | |
11 | ||
12 | ||
13 | def add_pets(session): | |
14 | pets = [ | |
15 | Pet(id=1, name="Lassie", pet_kind="dog", hair_kind=HairKind.LONG), | |
16 | Pet(id=2, name="Barf", pet_kind="dog", hair_kind=HairKind.LONG), | |
17 | Pet(id=3, name="Alf", pet_kind="cat", hair_kind=HairKind.LONG), | |
18 | ] | |
19 | session.add_all(pets) | |
20 | session.commit() | |
21 | ||
22 | ||
23 | def test_sort_enum(): | |
24 | class PetType(SQLAlchemyObjectType): | |
25 | class Meta: | |
26 | model = Pet | |
27 | ||
28 | sort_enum = PetType.sort_enum() | |
29 | assert isinstance(sort_enum, type(Enum)) | |
30 | assert sort_enum._meta.name == "PetTypeSortEnum" | |
31 | assert list(sort_enum._meta.enum.__members__) == [ | |
32 | "ID_ASC", | |
33 | "ID_DESC", | |
34 | "NAME_ASC", | |
35 | "NAME_DESC", | |
36 | "PET_KIND_ASC", | |
37 | "PET_KIND_DESC", | |
38 | "HAIR_KIND_ASC", | |
39 | "HAIR_KIND_DESC", | |
40 | "REPORTER_ID_ASC", | |
41 | "REPORTER_ID_DESC", | |
42 | ] | |
43 | assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" | |
44 | assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" | |
45 | assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" | |
46 | assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" | |
47 | ||
48 | ||
49 | def test_sort_enum_with_custom_name(): | |
50 | class PetType(SQLAlchemyObjectType): | |
51 | class Meta: | |
52 | model = Pet | |
53 | ||
54 | sort_enum = PetType.sort_enum(name="CustomSortName") | |
55 | assert isinstance(sort_enum, type(Enum)) | |
56 | assert sort_enum._meta.name == "CustomSortName" | |
57 | ||
58 | ||
59 | def test_sort_enum_cache(): | |
60 | class PetType(SQLAlchemyObjectType): | |
61 | class Meta: | |
62 | model = Pet | |
63 | ||
64 | sort_enum = PetType.sort_enum() | |
65 | sort_enum_2 = PetType.sort_enum() | |
66 | assert sort_enum_2 is sort_enum | |
67 | sort_enum_2 = PetType.sort_enum(name="PetTypeSortEnum") | |
68 | assert sort_enum_2 is sort_enum | |
69 | err_msg = "Sort enum for PetType has already been customized" | |
70 | with pytest.raises(ValueError, match=err_msg): | |
71 | PetType.sort_enum(name="CustomSortName") | |
72 | with pytest.raises(ValueError, match=err_msg): | |
73 | PetType.sort_enum(only_fields=["id"]) | |
74 | with pytest.raises(ValueError, match=err_msg): | |
75 | PetType.sort_enum(only_indexed=True) | |
76 | with pytest.raises(ValueError, match=err_msg): | |
77 | PetType.sort_enum(get_symbol_name=lambda: "foo") | |
78 | ||
79 | ||
80 | def test_sort_enum_with_excluded_field_in_object_type(): | |
81 | class PetType(SQLAlchemyObjectType): | |
82 | class Meta: | |
83 | model = Pet | |
84 | exclude_fields = ["reporter_id"] | |
85 | ||
86 | sort_enum = PetType.sort_enum() | |
87 | assert list(sort_enum._meta.enum.__members__) == [ | |
88 | "ID_ASC", | |
89 | "ID_DESC", | |
90 | "NAME_ASC", | |
91 | "NAME_DESC", | |
92 | "PET_KIND_ASC", | |
93 | "PET_KIND_DESC", | |
94 | "HAIR_KIND_ASC", | |
95 | "HAIR_KIND_DESC", | |
96 | ] | |
97 | ||
98 | ||
99 | def test_sort_enum_only_fields(): | |
100 | class PetType(SQLAlchemyObjectType): | |
101 | class Meta: | |
102 | model = Pet | |
103 | ||
104 | sort_enum = PetType.sort_enum(only_fields=["id", "name"]) | |
105 | assert list(sort_enum._meta.enum.__members__) == [ | |
106 | "ID_ASC", | |
107 | "ID_DESC", | |
108 | "NAME_ASC", | |
109 | "NAME_DESC", | |
110 | ] | |
111 | ||
112 | ||
113 | def test_sort_argument(): | |
114 | class PetType(SQLAlchemyObjectType): | |
115 | class Meta: | |
116 | model = Pet | |
117 | ||
118 | sort_arg = PetType.sort_argument() | |
119 | assert isinstance(sort_arg, Argument) | |
120 | ||
121 | assert isinstance(sort_arg.type, List) | |
122 | sort_enum = sort_arg.type._of_type | |
123 | assert isinstance(sort_enum, type(Enum)) | |
124 | assert sort_enum._meta.name == "PetTypeSortEnum" | |
125 | assert list(sort_enum._meta.enum.__members__) == [ | |
126 | "ID_ASC", | |
127 | "ID_DESC", | |
128 | "NAME_ASC", | |
129 | "NAME_DESC", | |
130 | "PET_KIND_ASC", | |
131 | "PET_KIND_DESC", | |
132 | "HAIR_KIND_ASC", | |
133 | "HAIR_KIND_DESC", | |
134 | "REPORTER_ID_ASC", | |
135 | "REPORTER_ID_DESC", | |
136 | ] | |
137 | assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" | |
138 | assert str(sort_enum.ID_DESC.value.value) == "pets.id DESC" | |
139 | assert str(sort_enum.HAIR_KIND_ASC.value.value) == "pets.hair_kind ASC" | |
140 | assert str(sort_enum.HAIR_KIND_DESC.value.value) == "pets.hair_kind DESC" | |
141 | ||
142 | assert sort_arg.default_value == ["ID_ASC"] | |
143 | assert str(sort_enum.ID_ASC.value.value) == "pets.id ASC" | |
144 | ||
145 | ||
146 | def test_sort_argument_with_excluded_fields_in_object_type(): | |
147 | class PetType(SQLAlchemyObjectType): | |
148 | class Meta: | |
149 | model = Pet | |
150 | exclude_fields = ["hair_kind", "reporter_id"] | |
151 | ||
152 | sort_arg = PetType.sort_argument() | |
153 | sort_enum = sort_arg.type._of_type | |
154 | assert list(sort_enum._meta.enum.__members__) == [ | |
155 | "ID_ASC", | |
156 | "ID_DESC", | |
157 | "NAME_ASC", | |
158 | "NAME_DESC", | |
159 | "PET_KIND_ASC", | |
160 | "PET_KIND_DESC", | |
161 | ] | |
162 | assert sort_arg.default_value == ["ID_ASC"] | |
163 | ||
164 | ||
165 | def test_sort_argument_only_fields(): | |
166 | class PetType(SQLAlchemyObjectType): | |
167 | class Meta: | |
168 | model = Pet | |
169 | only_fields = ["id", "pet_kind"] | |
170 | ||
171 | sort_arg = PetType.sort_argument() | |
172 | sort_enum = sort_arg.type._of_type | |
173 | assert list(sort_enum._meta.enum.__members__) == [ | |
174 | "ID_ASC", | |
175 | "ID_DESC", | |
176 | "PET_KIND_ASC", | |
177 | "PET_KIND_DESC", | |
178 | ] | |
179 | assert sort_arg.default_value == ["ID_ASC"] | |
180 | ||
181 | ||
182 | def test_sort_argument_for_multi_column_pk(): | |
183 | class MultiPkTestModel(Base): | |
184 | __tablename__ = "multi_pk_test_table" | |
185 | foo = sa.Column(sa.Integer, primary_key=True) | |
186 | bar = sa.Column(sa.Integer, primary_key=True) | |
187 | ||
188 | class MultiPkTestType(SQLAlchemyObjectType): | |
189 | class Meta: | |
190 | model = MultiPkTestModel | |
191 | ||
192 | sort_arg = MultiPkTestType.sort_argument() | |
193 | assert sort_arg.default_value == ["FOO_ASC", "BAR_ASC"] | |
194 | ||
195 | ||
196 | def test_sort_argument_only_indexed(): | |
197 | class IndexedTestModel(Base): | |
198 | __tablename__ = "indexed_test_table" | |
199 | id = sa.Column(sa.Integer, primary_key=True) | |
200 | foo = sa.Column(sa.Integer, index=False) | |
201 | bar = sa.Column(sa.Integer, index=True) | |
202 | ||
203 | class IndexedTestType(SQLAlchemyObjectType): | |
204 | class Meta: | |
205 | model = IndexedTestModel | |
206 | ||
207 | sort_arg = IndexedTestType.sort_argument(only_indexed=True) | |
208 | sort_enum = sort_arg.type._of_type | |
209 | assert list(sort_enum._meta.enum.__members__) == [ | |
210 | "ID_ASC", | |
211 | "ID_DESC", | |
212 | "BAR_ASC", | |
213 | "BAR_DESC", | |
214 | ] | |
215 | assert sort_arg.default_value == ["ID_ASC"] | |
216 | ||
217 | ||
218 | def test_sort_argument_with_custom_symbol_names(): | |
219 | class PetType(SQLAlchemyObjectType): | |
220 | class Meta: | |
221 | model = Pet | |
222 | ||
223 | def get_symbol_name(column_name, sort_asc=True): | |
224 | return to_type_name(column_name) + ("Up" if sort_asc else "Down") | |
225 | ||
226 | sort_arg = PetType.sort_argument(get_symbol_name=get_symbol_name) | |
227 | sort_enum = sort_arg.type._of_type | |
228 | assert list(sort_enum._meta.enum.__members__) == [ | |
229 | "IdUp", | |
230 | "IdDown", | |
231 | "NameUp", | |
232 | "NameDown", | |
233 | "PetKindUp", | |
234 | "PetKindDown", | |
235 | "HairKindUp", | |
236 | "HairKindDown", | |
237 | "ReporterIdUp", | |
238 | "ReporterIdDown", | |
239 | ] | |
240 | assert sort_arg.default_value == ["IdUp"] | |
241 | ||
242 | ||
243 | def test_sort_query(session): | |
244 | add_pets(session) | |
245 | ||
246 | class PetNode(SQLAlchemyObjectType): | |
247 | class Meta: | |
248 | model = Pet | |
249 | interfaces = (Node,) | |
250 | ||
251 | class Query(ObjectType): | |
252 | defaultSort = SQLAlchemyConnectionField(PetNode.connection) | |
253 | nameSort = SQLAlchemyConnectionField(PetNode.connection) | |
254 | multipleSort = SQLAlchemyConnectionField(PetNode.connection) | |
255 | descSort = SQLAlchemyConnectionField(PetNode.connection) | |
256 | singleColumnSort = SQLAlchemyConnectionField( | |
257 | PetNode.connection, sort=Argument(PetNode.sort_enum()) | |
258 | ) | |
259 | noDefaultSort = SQLAlchemyConnectionField( | |
260 | PetNode.connection, sort=PetNode.sort_argument(has_default=False) | |
261 | ) | |
262 | noSort = SQLAlchemyConnectionField(PetNode.connection, sort=None) | |
263 | ||
264 | query = """ | |
265 | query sortTest { | |
266 | defaultSort { | |
267 | edges { | |
268 | node { | |
269 | name | |
270 | } | |
271 | } | |
272 | } | |
273 | nameSort(sort: NAME_ASC) { | |
274 | edges { | |
275 | node { | |
276 | name | |
277 | } | |
278 | } | |
279 | } | |
280 | multipleSort(sort: [PET_KIND_ASC, NAME_DESC]) { | |
281 | edges { | |
282 | node { | |
283 | name | |
284 | petKind | |
285 | } | |
286 | } | |
287 | } | |
288 | descSort(sort: [NAME_DESC]) { | |
289 | edges { | |
290 | node { | |
291 | name | |
292 | } | |
293 | } | |
294 | } | |
295 | singleColumnSort(sort: NAME_DESC) { | |
296 | edges { | |
297 | node { | |
298 | name | |
299 | } | |
300 | } | |
301 | } | |
302 | noDefaultSort(sort: NAME_ASC) { | |
303 | edges { | |
304 | node { | |
305 | name | |
306 | } | |
307 | } | |
308 | } | |
309 | } | |
310 | """ | |
311 | ||
312 | def makeNodes(nodeList): | |
313 | nodes = [{"node": item} for item in nodeList] | |
314 | return {"edges": nodes} | |
315 | ||
316 | expected = { | |
317 | "defaultSort": makeNodes( | |
318 | [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] | |
319 | ), | |
320 | "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), | |
321 | "noDefaultSort": makeNodes( | |
322 | [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] | |
323 | ), | |
324 | "multipleSort": makeNodes( | |
325 | [ | |
326 | {"name": "Alf", "petKind": "CAT"}, | |
327 | {"name": "Lassie", "petKind": "DOG"}, | |
328 | {"name": "Barf", "petKind": "DOG"}, | |
329 | ] | |
330 | ), | |
331 | "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), | |
332 | "singleColumnSort": makeNodes( | |
333 | [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] | |
334 | ), | |
335 | } # yapf: disable | |
336 | ||
337 | schema = Schema(query=Query) | |
338 | result = schema.execute(query, context_value={"session": session}) | |
339 | assert not result.errors | |
340 | result = to_std_dicts(result.data) | |
341 | assert result == expected | |
342 | ||
343 | queryError = """ | |
344 | query sortTest { | |
345 | singleColumnSort(sort: [PET_KIND_ASC, NAME_DESC]) { | |
346 | edges { | |
347 | node { | |
348 | name | |
349 | } | |
350 | } | |
351 | } | |
352 | } | |
353 | """ | |
354 | result = schema.execute(queryError, context_value={"session": session}) | |
355 | assert result.errors is not None | |
356 | assert '"sort" has invalid value' in result.errors[0].message | |
357 | ||
358 | queryNoSort = """ | |
359 | query sortTest { | |
360 | noDefaultSort { | |
361 | edges { | |
362 | node { | |
363 | name | |
364 | } | |
365 | } | |
366 | } | |
367 | noSort { | |
368 | edges { | |
369 | node { | |
370 | name | |
371 | } | |
372 | } | |
373 | } | |
374 | } | |
375 | """ | |
376 | ||
377 | result = schema.execute(queryNoSort, context_value={"session": session}) | |
378 | assert not result.errors | |
379 | # TODO: SQLite usually returns the results ordered by primary key, | |
380 | # so we cannot test this way whether sorting actually happens or not. | |
381 | # Also, no sort order is guaranteed by SQLite if "no order" by is used. | |
382 | assert [node["node"]["name"] for node in result.data["noSort"]["edges"]] == [ | |
383 | node["node"]["name"] for node in result.data["noDefaultSort"]["edges"] | |
384 | ] |
0 | from collections import OrderedDict | |
1 | ||
0 | import mock | |
1 | import pytest | |
2 | 2 | import six # noqa F401 |
3 | from promise import Promise | |
4 | ||
5 | from graphene import (Connection, Field, Int, Interface, Node, ObjectType, | |
6 | is_node) | |
7 | ||
3 | ||
4 | from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, | |
5 | ObjectType, Schema, String) | |
6 | from graphene.relay import Connection | |
7 | ||
8 | from ..converter import convert_sqlalchemy_composite | |
8 | 9 | from ..fields import (SQLAlchemyConnectionField, |
9 | UnsortedSQLAlchemyConnectionField, | |
10 | UnsortedSQLAlchemyConnectionField, createConnectionField, | |
10 | 11 | registerConnectionFieldFactory, |
11 | 12 | unregisterConnectionFieldFactory) |
12 | from ..registry import Registry | |
13 | from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions | |
14 | from .models import Article, Reporter | |
15 | ||
16 | registry = Registry() | |
17 | ||
18 | ||
19 | class Character(SQLAlchemyObjectType): | |
20 | """Character description""" | |
21 | ||
22 | class Meta: | |
23 | model = Reporter | |
24 | registry = registry | |
25 | ||
26 | ||
27 | class Human(SQLAlchemyObjectType): | |
28 | """Human description""" | |
29 | ||
30 | pub_date = Int() | |
31 | ||
32 | class Meta: | |
33 | model = Article | |
34 | exclude_fields = ("id",) | |
35 | registry = registry | |
36 | interfaces = (Node,) | |
37 | ||
38 | ||
39 | def test_sqlalchemy_interface(): | |
40 | assert issubclass(Node, Interface) | |
41 | assert issubclass(Node, Node) | |
42 | ||
43 | ||
44 | # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) | |
45 | # def test_sqlalchemy_get_node(get): | |
46 | # human = Human.get_node(1, None) | |
47 | # get.assert_called_with(id=1) | |
48 | # assert human.id == 1 | |
49 | ||
50 | ||
51 | def test_objecttype_registered(): | |
52 | assert issubclass(Character, ObjectType) | |
53 | assert Character._meta.model == Reporter | |
54 | assert list(Character._meta.fields.keys()) == [ | |
13 | from ..types import ORMField, SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions | |
14 | from .models import Article, CompositeFullName, Pet, Reporter | |
15 | ||
16 | ||
17 | def test_should_raise_if_no_model(): | |
18 | re_err = r"valid SQLAlchemy Model" | |
19 | with pytest.raises(Exception, match=re_err): | |
20 | class Character1(SQLAlchemyObjectType): | |
21 | pass | |
22 | ||
23 | ||
24 | def test_should_raise_if_model_is_invalid(): | |
25 | re_err = r"valid SQLAlchemy Model" | |
26 | with pytest.raises(Exception, match=re_err): | |
27 | class Character(SQLAlchemyObjectType): | |
28 | class Meta: | |
29 | model = 1 | |
30 | ||
31 | ||
32 | def test_sqlalchemy_node(session): | |
33 | class ReporterType(SQLAlchemyObjectType): | |
34 | class Meta: | |
35 | model = Reporter | |
36 | interfaces = (Node,) | |
37 | ||
38 | reporter_id_field = ReporterType._meta.fields["id"] | |
39 | assert isinstance(reporter_id_field, GlobalID) | |
40 | ||
41 | reporter = Reporter() | |
42 | session.add(reporter) | |
43 | session.commit() | |
44 | info = mock.Mock(context={'session': session}) | |
45 | reporter_node = ReporterType.get_node(info, reporter.id) | |
46 | assert reporter == reporter_node | |
47 | ||
48 | ||
49 | def test_connection(): | |
50 | class ReporterType(SQLAlchemyObjectType): | |
51 | class Meta: | |
52 | model = Reporter | |
53 | interfaces = (Node,) | |
54 | ||
55 | assert issubclass(ReporterType.connection, Connection) | |
56 | ||
57 | ||
58 | def test_sqlalchemy_default_fields(): | |
59 | @convert_sqlalchemy_composite.register(CompositeFullName) | |
60 | def convert_composite_class(composite, registry): | |
61 | return String() | |
62 | ||
63 | class ReporterType(SQLAlchemyObjectType): | |
64 | class Meta: | |
65 | model = Reporter | |
66 | interfaces = (Node,) | |
67 | ||
68 | class ArticleType(SQLAlchemyObjectType): | |
69 | class Meta: | |
70 | model = Article | |
71 | interfaces = (Node,) | |
72 | ||
73 | assert list(ReporterType._meta.fields.keys()) == [ | |
74 | # Columns | |
75 | "column_prop", # SQLAlchemy retuns column properties first | |
55 | 76 | "id", |
56 | 77 | "first_name", |
57 | 78 | "last_name", |
58 | 79 | "email", |
80 | "favorite_pet_kind", | |
81 | # Composite | |
82 | "composite_prop", | |
83 | # Hybrid | |
84 | "hybrid_prop", | |
85 | # Relationship | |
59 | 86 | "pets", |
60 | 87 | "articles", |
61 | 88 | "favorite_article", |
62 | 89 | ] |
63 | 90 | |
64 | ||
65 | # def test_sqlalchemynode_idfield(): | |
66 | # idfield = Node._meta.fields_map['id'] | |
67 | # assert isinstance(idfield, GlobalIDField) | |
68 | ||
69 | ||
70 | # def test_node_idfield(): | |
71 | # idfield = Human._meta.fields_map['id'] | |
72 | # assert isinstance(idfield, GlobalIDField) | |
73 | ||
74 | ||
75 | def test_node_replacedfield(): | |
76 | idfield = Human._meta.fields["pub_date"] | |
77 | assert isinstance(idfield, Field) | |
78 | assert idfield.type == Int | |
79 | ||
80 | ||
81 | def test_object_type(): | |
82 | class Human(SQLAlchemyObjectType): | |
83 | """Human description""" | |
84 | ||
85 | pub_date = Int() | |
86 | ||
91 | # column | |
92 | first_name_field = ReporterType._meta.fields['first_name'] | |
93 | assert first_name_field.type == String | |
94 | assert first_name_field.description == "First name" | |
95 | ||
96 | # column_property | |
97 | column_prop_field = ReporterType._meta.fields['column_prop'] | |
98 | assert column_prop_field.type == Int | |
99 | # "doc" is ignored by column_property | |
100 | assert column_prop_field.description is None | |
101 | ||
102 | # composite | |
103 | full_name_field = ReporterType._meta.fields['composite_prop'] | |
104 | assert full_name_field.type == String | |
105 | # "doc" is ignored by composite | |
106 | assert full_name_field.description is None | |
107 | ||
108 | # hybrid_property | |
109 | hybrid_prop = ReporterType._meta.fields['hybrid_prop'] | |
110 | assert hybrid_prop.type == String | |
111 | # "doc" is ignored by hybrid_property | |
112 | assert hybrid_prop.description is None | |
113 | ||
114 | # relationship | |
115 | favorite_article_field = ReporterType._meta.fields['favorite_article'] | |
116 | assert isinstance(favorite_article_field, Dynamic) | |
117 | assert favorite_article_field.type().type == ArticleType | |
118 | assert favorite_article_field.type().description is None | |
119 | ||
120 | ||
121 | def test_sqlalchemy_override_fields(): | |
122 | @convert_sqlalchemy_composite.register(CompositeFullName) | |
123 | def convert_composite_class(composite, registry): | |
124 | return String() | |
125 | ||
126 | class ReporterMixin(object): | |
127 | # columns | |
128 | first_name = ORMField(required=True) | |
129 | last_name = ORMField(description='Overridden') | |
130 | ||
131 | class ReporterType(SQLAlchemyObjectType, ReporterMixin): | |
132 | class Meta: | |
133 | model = Reporter | |
134 | interfaces = (Node,) | |
135 | ||
136 | # columns | |
137 | email = ORMField(deprecation_reason='Overridden') | |
138 | email_v2 = ORMField(model_attr='email', type=Int) | |
139 | ||
140 | # column_property | |
141 | column_prop = ORMField(type=String) | |
142 | ||
143 | # composite | |
144 | composite_prop = ORMField() | |
145 | ||
146 | # hybrid_property | |
147 | hybrid_prop = ORMField(description='Overridden') | |
148 | ||
149 | # relationships | |
150 | favorite_article = ORMField(description='Overridden') | |
151 | articles = ORMField(deprecation_reason='Overridden') | |
152 | pets = ORMField(description='Overridden') | |
153 | ||
154 | class ArticleType(SQLAlchemyObjectType): | |
87 | 155 | class Meta: |
88 | 156 | model = Article |
89 | # exclude_fields = ('id', ) | |
90 | registry = registry | |
91 | interfaces = (Node,) | |
92 | ||
93 | assert issubclass(Human, ObjectType) | |
94 | assert list(Human._meta.fields.keys()) == [ | |
95 | "id", | |
96 | "headline", | |
97 | "pub_date", | |
98 | "reporter_id", | |
99 | "reporter", | |
100 | ] | |
101 | assert is_node(Human) | |
102 | ||
103 | ||
104 | # Test Custom SQLAlchemyObjectType Implementation | |
105 | class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): | |
106 | class Meta: | |
107 | abstract = True | |
108 | ||
109 | ||
110 | class CustomCharacter(CustomSQLAlchemyObjectType): | |
111 | """Character description""" | |
112 | ||
113 | class Meta: | |
114 | model = Reporter | |
115 | registry = registry | |
116 | ||
117 | ||
118 | def test_custom_objecttype_registered(): | |
119 | assert issubclass(CustomCharacter, ObjectType) | |
120 | assert CustomCharacter._meta.model == Reporter | |
121 | assert list(CustomCharacter._meta.fields.keys()) == [ | |
122 | "id", | |
157 | interfaces = (Node,) | |
158 | ||
159 | class PetType(SQLAlchemyObjectType): | |
160 | class Meta: | |
161 | model = Pet | |
162 | interfaces = (Node,) | |
163 | use_connection = False | |
164 | ||
165 | assert list(ReporterType._meta.fields.keys()) == [ | |
166 | # Fields from ReporterMixin | |
123 | 167 | "first_name", |
124 | 168 | "last_name", |
169 | # Fields from ReporterType | |
125 | 170 | "email", |
171 | "email_v2", | |
172 | "column_prop", | |
173 | "composite_prop", | |
174 | "hybrid_prop", | |
175 | "favorite_article", | |
176 | "articles", | |
177 | "pets", | |
178 | # Then the automatic SQLAlchemy fields | |
179 | "id", | |
180 | "favorite_pet_kind", | |
181 | ] | |
182 | ||
183 | first_name_field = ReporterType._meta.fields['first_name'] | |
184 | assert isinstance(first_name_field.type, NonNull) | |
185 | assert first_name_field.type.of_type == String | |
186 | assert first_name_field.description == "First name" | |
187 | assert first_name_field.deprecation_reason is None | |
188 | ||
189 | last_name_field = ReporterType._meta.fields['last_name'] | |
190 | assert last_name_field.type == String | |
191 | assert last_name_field.description == "Overridden" | |
192 | assert last_name_field.deprecation_reason is None | |
193 | ||
194 | email_field = ReporterType._meta.fields['email'] | |
195 | assert email_field.type == String | |
196 | assert email_field.description == "Email" | |
197 | assert email_field.deprecation_reason == "Overridden" | |
198 | ||
199 | email_field_v2 = ReporterType._meta.fields['email_v2'] | |
200 | assert email_field_v2.type == Int | |
201 | assert email_field_v2.description == "Email" | |
202 | assert email_field_v2.deprecation_reason is None | |
203 | ||
204 | hybrid_prop_field = ReporterType._meta.fields['hybrid_prop'] | |
205 | assert hybrid_prop_field.type == String | |
206 | assert hybrid_prop_field.description == "Overridden" | |
207 | assert hybrid_prop_field.deprecation_reason is None | |
208 | ||
209 | column_prop_field_v2 = ReporterType._meta.fields['column_prop'] | |
210 | assert column_prop_field_v2.type == String | |
211 | assert column_prop_field_v2.description is None | |
212 | assert column_prop_field_v2.deprecation_reason is None | |
213 | ||
214 | composite_prop_field = ReporterType._meta.fields['composite_prop'] | |
215 | assert composite_prop_field.type == String | |
216 | assert composite_prop_field.description is None | |
217 | assert composite_prop_field.deprecation_reason is None | |
218 | ||
219 | favorite_article_field = ReporterType._meta.fields['favorite_article'] | |
220 | assert isinstance(favorite_article_field, Dynamic) | |
221 | assert favorite_article_field.type().type == ArticleType | |
222 | assert favorite_article_field.type().description == 'Overridden' | |
223 | ||
224 | articles_field = ReporterType._meta.fields['articles'] | |
225 | assert isinstance(articles_field, Dynamic) | |
226 | assert isinstance(articles_field.type(), UnsortedSQLAlchemyConnectionField) | |
227 | assert articles_field.type().deprecation_reason == "Overridden" | |
228 | ||
229 | pets_field = ReporterType._meta.fields['pets'] | |
230 | assert isinstance(pets_field, Dynamic) | |
231 | assert isinstance(pets_field.type().type, List) | |
232 | assert pets_field.type().type.of_type == PetType | |
233 | assert pets_field.type().description == 'Overridden' | |
234 | ||
235 | ||
236 | def test_invalid_model_attr(): | |
237 | err_msg = ( | |
238 | "Cannot map ORMField to a model attribute.\n" | |
239 | "Field: 'ReporterType.first_name'" | |
240 | ) | |
241 | with pytest.raises(ValueError, match=err_msg): | |
242 | class ReporterType(SQLAlchemyObjectType): | |
243 | class Meta: | |
244 | model = Reporter | |
245 | ||
246 | first_name = ORMField(model_attr='does_not_exist') | |
247 | ||
248 | ||
249 | def test_only_fields(): | |
250 | class ReporterType(SQLAlchemyObjectType): | |
251 | class Meta: | |
252 | model = Reporter | |
253 | only_fields = ("id", "last_name") | |
254 | ||
255 | first_name = ORMField() # Takes precedence | |
256 | last_name = ORMField() # Noop | |
257 | ||
258 | assert list(ReporterType._meta.fields.keys()) == ["first_name", "last_name", "id"] | |
259 | ||
260 | ||
261 | def test_exclude_fields(): | |
262 | class ReporterType(SQLAlchemyObjectType): | |
263 | class Meta: | |
264 | model = Reporter | |
265 | exclude_fields = ("id", "first_name") | |
266 | ||
267 | first_name = ORMField() # Takes precedence | |
268 | last_name = ORMField() # Noop | |
269 | ||
270 | assert list(ReporterType._meta.fields.keys()) == [ | |
271 | "first_name", | |
272 | "last_name", | |
273 | "column_prop", | |
274 | "email", | |
275 | "favorite_pet_kind", | |
276 | "composite_prop", | |
277 | "hybrid_prop", | |
126 | 278 | "pets", |
127 | 279 | "articles", |
128 | 280 | "favorite_article", |
129 | 281 | ] |
130 | 282 | |
131 | 283 | |
284 | def test_only_and_exclude_fields(): | |
285 | re_err = r"'only_fields' and 'exclude_fields' cannot be both set" | |
286 | with pytest.raises(Exception, match=re_err): | |
287 | class ReporterType(SQLAlchemyObjectType): | |
288 | class Meta: | |
289 | model = Reporter | |
290 | only_fields = ("id", "last_name") | |
291 | exclude_fields = ("id", "last_name") | |
292 | ||
293 | ||
294 | def test_sqlalchemy_redefine_field(): | |
295 | class ReporterType(SQLAlchemyObjectType): | |
296 | class Meta: | |
297 | model = Reporter | |
298 | ||
299 | first_name = Int() | |
300 | ||
301 | first_name_field = ReporterType._meta.fields["first_name"] | |
302 | assert isinstance(first_name_field, Field) | |
303 | assert first_name_field.type == Int | |
304 | ||
305 | ||
306 | def test_resolvers(session): | |
307 | """Test that the correct resolver functions are called""" | |
308 | ||
309 | class ReporterMixin(object): | |
310 | def resolve_id(root, _info): | |
311 | return 'ID' | |
312 | ||
313 | class ReporterType(ReporterMixin, SQLAlchemyObjectType): | |
314 | class Meta: | |
315 | model = Reporter | |
316 | ||
317 | email = ORMField() | |
318 | email_v2 = ORMField(model_attr='email') | |
319 | favorite_pet_kind = Field(String) | |
320 | favorite_pet_kind_v2 = Field(String) | |
321 | ||
322 | def resolve_last_name(root, _info): | |
323 | return root.last_name.upper() | |
324 | ||
325 | def resolve_email_v2(root, _info): | |
326 | return root.email + '_V2' | |
327 | ||
328 | def resolve_favorite_pet_kind_v2(root, _info): | |
329 | return str(root.favorite_pet_kind) + '_V2' | |
330 | ||
331 | class Query(ObjectType): | |
332 | reporter = Field(ReporterType) | |
333 | ||
334 | def resolve_reporter(self, _info): | |
335 | return session.query(Reporter).first() | |
336 | ||
337 | reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat') | |
338 | session.add(reporter) | |
339 | session.commit() | |
340 | ||
341 | schema = Schema(query=Query) | |
342 | result = schema.execute(""" | |
343 | query { | |
344 | reporter { | |
345 | id | |
346 | firstName | |
347 | lastName | |
348 | ||
349 | emailV2 | |
350 | favoritePetKind | |
351 | favoritePetKindV2 | |
352 | } | |
353 | } | |
354 | """) | |
355 | ||
356 | assert not result.errors | |
357 | # Custom resolver on a base class | |
358 | assert result.data['reporter']['id'] == 'ID' | |
359 | # Default field + default resolver | |
360 | assert result.data['reporter']['firstName'] == 'first_name' | |
361 | # Default field + custom resolver | |
362 | assert result.data['reporter']['lastName'] == 'LAST_NAME' | |
363 | # ORMField + default resolver | |
364 | assert result.data['reporter']['email'] == 'email' | |
365 | # ORMField + custom resolver | |
366 | assert result.data['reporter']['emailV2'] == 'email_V2' | |
367 | # Field + default resolver | |
368 | assert result.data['reporter']['favoritePetKind'] == 'cat' | |
369 | # Field + custom resolver | |
370 | assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2' | |
371 | ||
372 | ||
373 | # Test Custom SQLAlchemyObjectType Implementation | |
374 | ||
375 | def test_custom_objecttype_registered(): | |
376 | class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): | |
377 | class Meta: | |
378 | abstract = True | |
379 | ||
380 | class CustomReporterType(CustomSQLAlchemyObjectType): | |
381 | class Meta: | |
382 | model = Reporter | |
383 | ||
384 | assert issubclass(CustomReporterType, ObjectType) | |
385 | assert CustomReporterType._meta.model == Reporter | |
386 | assert len(CustomReporterType._meta.fields) == 11 | |
387 | ||
388 | ||
132 | 389 | # Test Custom SQLAlchemyObjectType with Custom Options |
133 | class CustomOptions(SQLAlchemyObjectTypeOptions): | |
134 | custom_option = None | |
135 | custom_fields = None | |
136 | ||
137 | ||
138 | class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): | |
139 | class Meta: | |
140 | abstract = True | |
141 | ||
142 | @classmethod | |
143 | def __init_subclass_with_meta__( | |
144 | cls, custom_option=None, custom_fields=None, **options | |
145 | ): | |
146 | _meta = CustomOptions(cls) | |
147 | _meta.custom_option = custom_option | |
148 | _meta.fields = custom_fields | |
149 | super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( | |
150 | _meta=_meta, **options | |
151 | ) | |
152 | ||
153 | ||
154 | class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): | |
155 | class Meta: | |
156 | model = Reporter | |
157 | custom_option = "custom_option" | |
158 | custom_fields = OrderedDict([("custom_field", Field(Int()))]) | |
159 | ||
160 | ||
161 | 390 | def test_objecttype_with_custom_options(): |
391 | class CustomOptions(SQLAlchemyObjectTypeOptions): | |
392 | custom_option = None | |
393 | ||
394 | class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): | |
395 | class Meta: | |
396 | abstract = True | |
397 | ||
398 | @classmethod | |
399 | def __init_subclass_with_meta__(cls, custom_option=None, **options): | |
400 | _meta = CustomOptions(cls) | |
401 | _meta.custom_option = custom_option | |
402 | super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( | |
403 | _meta=_meta, **options | |
404 | ) | |
405 | ||
406 | class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): | |
407 | class Meta: | |
408 | model = Reporter | |
409 | custom_option = "custom_option" | |
410 | ||
162 | 411 | assert issubclass(ReporterWithCustomOptions, ObjectType) |
163 | 412 | assert ReporterWithCustomOptions._meta.model == Reporter |
164 | assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ | |
165 | "custom_field", | |
166 | "id", | |
167 | "first_name", | |
168 | "last_name", | |
169 | "email", | |
170 | "pets", | |
171 | "articles", | |
172 | "favorite_article", | |
173 | ] | |
174 | 413 | assert ReporterWithCustomOptions._meta.custom_option == "custom_option" |
175 | assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) | |
176 | ||
177 | ||
178 | def test_promise_connection_resolver(): | |
179 | class TestConnection(Connection): | |
180 | class Meta: | |
181 | node = ReporterWithCustomOptions | |
182 | ||
183 | def resolver(*args, **kwargs): | |
184 | return Promise.resolve([]) | |
185 | ||
186 | result = SQLAlchemyConnectionField.connection_resolver( | |
187 | resolver, TestConnection, ReporterWithCustomOptions, None, None | |
188 | ) | |
189 | assert result is not None | |
190 | 414 | |
191 | 415 | |
192 | 416 | # Tests for connection_field_factory |
196 | 420 | |
197 | 421 | |
198 | 422 | def test_default_connection_field_factory(): |
199 | _registry = Registry() | |
200 | ||
201 | class ReporterType(SQLAlchemyObjectType): | |
202 | class Meta: | |
203 | model = Reporter | |
204 | registry = _registry | |
423 | class ReporterType(SQLAlchemyObjectType): | |
424 | class Meta: | |
425 | model = Reporter | |
205 | 426 | interfaces = (Node,) |
206 | 427 | |
207 | 428 | class ArticleType(SQLAlchemyObjectType): |
208 | 429 | class Meta: |
209 | 430 | model = Article |
210 | registry = _registry | |
211 | 431 | interfaces = (Node,) |
212 | 432 | |
213 | 433 | assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) |
214 | 434 | |
215 | 435 | |
216 | def test_register_connection_field_factory(): | |
436 | def test_custom_connection_field_factory(): | |
217 | 437 | def test_connection_field_factory(relationship, registry): |
218 | 438 | model = relationship.mapper.entity |
219 | 439 | _type = registry.get_type_for_model(model) |
220 | 440 | return _TestSQLAlchemyConnectionField(_type._meta.connection) |
221 | 441 | |
222 | _registry = Registry() | |
223 | ||
224 | class ReporterType(SQLAlchemyObjectType): | |
225 | class Meta: | |
226 | model = Reporter | |
227 | registry = _registry | |
442 | class ReporterType(SQLAlchemyObjectType): | |
443 | class Meta: | |
444 | model = Reporter | |
228 | 445 | interfaces = (Node,) |
229 | 446 | connection_field_factory = test_connection_field_factory |
230 | 447 | |
231 | 448 | class ArticleType(SQLAlchemyObjectType): |
232 | 449 | class Meta: |
233 | 450 | model = Article |
234 | registry = _registry | |
235 | 451 | interfaces = (Node,) |
236 | 452 | |
237 | 453 | assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) |
238 | 454 | |
239 | 455 | |
240 | 456 | def test_deprecated_registerConnectionFieldFactory(): |
241 | registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) | |
242 | ||
243 | _registry = Registry() | |
244 | ||
245 | class ReporterType(SQLAlchemyObjectType): | |
246 | class Meta: | |
247 | model = Reporter | |
248 | registry = _registry | |
249 | interfaces = (Node,) | |
250 | ||
251 | class ArticleType(SQLAlchemyObjectType): | |
252 | class Meta: | |
253 | model = Article | |
254 | registry = _registry | |
255 | interfaces = (Node,) | |
256 | ||
257 | assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) | |
457 | with pytest.warns(DeprecationWarning): | |
458 | registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) | |
459 | ||
460 | class ReporterType(SQLAlchemyObjectType): | |
461 | class Meta: | |
462 | model = Reporter | |
463 | interfaces = (Node,) | |
464 | ||
465 | class ArticleType(SQLAlchemyObjectType): | |
466 | class Meta: | |
467 | model = Article | |
468 | interfaces = (Node,) | |
469 | ||
470 | assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) | |
258 | 471 | |
259 | 472 | |
260 | 473 | def test_deprecated_unregisterConnectionFieldFactory(): |
261 | registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) | |
262 | unregisterConnectionFieldFactory() | |
263 | ||
264 | _registry = Registry() | |
265 | ||
266 | class ReporterType(SQLAlchemyObjectType): | |
267 | class Meta: | |
268 | model = Reporter | |
269 | registry = _registry | |
270 | interfaces = (Node,) | |
271 | ||
272 | class ArticleType(SQLAlchemyObjectType): | |
273 | class Meta: | |
274 | model = Article | |
275 | registry = _registry | |
276 | interfaces = (Node,) | |
277 | ||
278 | assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) | |
474 | with pytest.warns(DeprecationWarning): | |
475 | registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) | |
476 | unregisterConnectionFieldFactory() | |
477 | ||
478 | class ReporterType(SQLAlchemyObjectType): | |
479 | class Meta: | |
480 | model = Reporter | |
481 | interfaces = (Node,) | |
482 | ||
483 | class ArticleType(SQLAlchemyObjectType): | |
484 | class Meta: | |
485 | model = Article | |
486 | interfaces = (Node,) | |
487 | ||
488 | assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) | |
489 | ||
490 | ||
491 | def test_deprecated_createConnectionField(): | |
492 | with pytest.warns(DeprecationWarning): | |
493 | createConnectionField(None) |
0 | import pytest | |
0 | 1 | import sqlalchemy as sa |
1 | 2 | |
2 | 3 | from graphene import Enum, List, ObjectType, Schema, String |
3 | 4 | |
4 | from ..utils import get_session, sort_argument_for_model, sort_enum_for_model | |
5 | from .models import Editor, Pet | |
5 | from ..utils import (get_session, sort_argument_for_model, sort_enum_for_model, | |
6 | to_enum_value_name, to_type_name) | |
7 | from .models import Base, Editor, Pet | |
6 | 8 | |
7 | 9 | |
8 | 10 | def test_get_session(): |
26 | 28 | assert result.data["x"] == session |
27 | 29 | |
28 | 30 | |
31 | def test_to_type_name(): | |
32 | assert to_type_name("make_camel_case") == "MakeCamelCase" | |
33 | assert to_type_name("AlreadyCamelCase") == "AlreadyCamelCase" | |
34 | assert to_type_name("A_Snake_and_a_Camel") == "ASnakeAndACamel" | |
35 | ||
36 | ||
37 | def test_to_enum_value_name(): | |
38 | assert to_enum_value_name("make_enum_value_name") == "MAKE_ENUM_VALUE_NAME" | |
39 | assert to_enum_value_name("makeEnumValueName") == "MAKE_ENUM_VALUE_NAME" | |
40 | assert to_enum_value_name("HTTPStatus400Message") == "HTTP_STATUS400_MESSAGE" | |
41 | assert to_enum_value_name("ALREADY_ENUM_VALUE_NAME") == "ALREADY_ENUM_VALUE_NAME" | |
42 | ||
43 | ||
44 | # test deprecated sort enum utility functions | |
45 | ||
46 | ||
29 | 47 | def test_sort_enum_for_model(): |
30 | enum = sort_enum_for_model(Pet) | |
48 | with pytest.warns(DeprecationWarning): | |
49 | enum = sort_enum_for_model(Pet) | |
31 | 50 | assert isinstance(enum, type(Enum)) |
32 | 51 | assert str(enum) == "PetSortEnum" |
33 | 52 | for col in sa.inspect(Pet).columns: |
36 | 55 | |
37 | 56 | |
38 | 57 | def test_sort_enum_for_model_custom_naming(): |
39 | enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")) | |
58 | with pytest.warns(DeprecationWarning): | |
59 | enum = sort_enum_for_model( | |
60 | Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D") | |
61 | ) | |
40 | 62 | assert str(enum) == "Foo" |
41 | 63 | for col in sa.inspect(Pet).columns: |
42 | 64 | assert hasattr(enum, col.name.upper() + "A") |
44 | 66 | |
45 | 67 | |
46 | 68 | def test_enum_cache(): |
47 | assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) | |
69 | with pytest.warns(DeprecationWarning): | |
70 | assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) | |
48 | 71 | |
49 | 72 | |
50 | 73 | def test_sort_argument_for_model(): |
51 | arg = sort_argument_for_model(Pet) | |
74 | with pytest.warns(DeprecationWarning): | |
75 | arg = sort_argument_for_model(Pet) | |
52 | 76 | |
53 | 77 | assert isinstance(arg.type, List) |
54 | 78 | assert arg.default_value == [Pet.id.name + "_asc"] |
55 | assert arg.type.of_type == sort_enum_for_model(Pet) | |
79 | with pytest.warns(DeprecationWarning): | |
80 | assert arg.type.of_type is sort_enum_for_model(Pet) | |
56 | 81 | |
57 | 82 | |
58 | 83 | def test_sort_argument_for_model_no_default(): |
59 | arg = sort_argument_for_model(Pet, False) | |
84 | with pytest.warns(DeprecationWarning): | |
85 | arg = sort_argument_for_model(Pet, False) | |
60 | 86 | |
61 | 87 | assert arg.default_value is None |
62 | 88 | |
63 | 89 | |
64 | 90 | def test_sort_argument_for_model_multiple_pk(): |
65 | Base = sa.ext.declarative.declarative_base() | |
66 | ||
67 | 91 | class MultiplePK(Base): |
68 | 92 | foo = sa.Column(sa.Integer, primary_key=True) |
69 | 93 | bar = sa.Column(sa.Integer, primary_key=True) |
70 | 94 | __tablename__ = "MultiplePK" |
71 | 95 | |
72 | arg = sort_argument_for_model(MultiplePK) | |
96 | with pytest.warns(DeprecationWarning): | |
97 | arg = sort_argument_for_model(MultiplePK) | |
73 | 98 | assert set(arg.default_value) == set( |
74 | 99 | (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") |
75 | 100 | ) |
0 | import pkg_resources | |
1 | ||
2 | ||
3 | def to_std_dicts(value): | |
4 | """Convert nested ordered dicts to normal dicts for better comparison.""" | |
5 | if isinstance(value, dict): | |
6 | return {k: to_std_dicts(v) for k, v in value.items()} | |
7 | elif isinstance(value, list): | |
8 | return [to_std_dicts(v) for v in value] | |
9 | else: | |
10 | return value | |
11 | ||
12 | ||
13 | def is_sqlalchemy_version_less_than(version_string): | |
14 | """Check the installed SQLAlchemy version""" | |
15 | return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) |
1 | 1 | |
2 | 2 | import sqlalchemy |
3 | 3 | from sqlalchemy.ext.hybrid import hybrid_property |
4 | from sqlalchemy.inspection import inspect as sqlalchemyinspect | |
4 | from sqlalchemy.orm import (ColumnProperty, CompositeProperty, | |
5 | RelationshipProperty) | |
5 | 6 | from sqlalchemy.orm.exc import NoResultFound |
6 | 7 | |
7 | from graphene import Field # , annotate, ResolveInfo | |
8 | from graphene import Field | |
8 | 9 | from graphene.relay import Connection, Node |
9 | 10 | from graphene.types.objecttype import ObjectType, ObjectTypeOptions |
10 | 11 | from graphene.types.utils import yank_fields_from_attrs |
12 | from graphene.utils.orderedtype import OrderedType | |
11 | 13 | |
12 | 14 | from .converter import (convert_sqlalchemy_column, |
13 | 15 | convert_sqlalchemy_composite, |
14 | 16 | convert_sqlalchemy_hybrid_method, |
15 | 17 | convert_sqlalchemy_relationship) |
16 | from .fields import default_connection_field_factory | |
18 | from .enums import (enum_for_field, sort_argument_for_object_type, | |
19 | sort_enum_for_object_type) | |
17 | 20 | from .registry import Registry, get_global_registry |
21 | from .resolvers import get_attr_resolver, get_custom_resolver | |
18 | 22 | from .utils import get_query, is_mapped_class, is_mapped_instance |
19 | 23 | |
20 | 24 | |
21 | def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory): | |
22 | inspected_model = sqlalchemyinspect(model) | |
23 | ||
25 | class ORMField(OrderedType): | |
26 | def __init__( | |
27 | self, | |
28 | model_attr=None, | |
29 | type=None, | |
30 | required=None, | |
31 | description=None, | |
32 | deprecation_reason=None, | |
33 | batching=None, | |
34 | _creation_counter=None, | |
35 | **field_kwargs | |
36 | ): | |
37 | """ | |
38 | Use this to override fields automatically generated by SQLAlchemyObjectType. | |
39 | Unless specified, options will default to SQLAlchemyObjectType usual behavior | |
40 | for the given SQLAlchemy model property. | |
41 | ||
42 | Usage: | |
43 | class MyModel(Base): | |
44 | id = Column(Integer(), primary_key=True) | |
45 | name = Column(String) | |
46 | ||
47 | class MyType(SQLAlchemyObjectType): | |
48 | class Meta: | |
49 | model = MyModel | |
50 | ||
51 | id = ORMField(type=graphene.Int) | |
52 | name = ORMField(required=True) | |
53 | ||
54 | -> MyType.id will be of type Int (vs ID). | |
55 | -> MyType.name will be of type NonNull(String) (vs String). | |
56 | ||
57 | :param str model_attr: | |
58 | Name of the SQLAlchemy model attribute used to resolve this field. | |
59 | Default to the name of the attribute referencing the ORMField. | |
60 | :param type: | |
61 | Default to the type mapping in converter.py. | |
62 | :param str description: | |
63 | Default to the `doc` attribute of the SQLAlchemy column property. | |
64 | :param bool required: | |
65 | Default to the opposite of the `nullable` attribute of the SQLAlchemy column property. | |
66 | :param str description: | |
67 | Same behavior as in graphene.Field. Defaults to None. | |
68 | :param str deprecation_reason: | |
69 | Same behavior as in graphene.Field. Defaults to None. | |
70 | :param bool batching: | |
71 | Toggle SQL batching. Defaults to None, that is `SQLAlchemyObjectType.meta.batching`. | |
72 | :param int _creation_counter: | |
73 | Same behavior as in graphene.Field. | |
74 | """ | |
75 | super(ORMField, self).__init__(_creation_counter=_creation_counter) | |
76 | # The is only useful for documentation and auto-completion | |
77 | common_kwargs = { | |
78 | 'model_attr': model_attr, | |
79 | 'type': type, | |
80 | 'required': required, | |
81 | 'description': description, | |
82 | 'deprecation_reason': deprecation_reason, | |
83 | 'batching': batching, | |
84 | } | |
85 | common_kwargs = {kwarg: value for kwarg, value in common_kwargs.items() if value is not None} | |
86 | self.kwargs = field_kwargs | |
87 | self.kwargs.update(common_kwargs) | |
88 | ||
89 | ||
90 | def construct_fields( | |
91 | obj_type, model, registry, only_fields, exclude_fields, batching, connection_field_factory | |
92 | ): | |
93 | """ | |
94 | Construct all the fields for a SQLAlchemyObjectType. | |
95 | The main steps are: | |
96 | - Gather all the relevant attributes from the SQLAlchemy model | |
97 | - Gather all the ORM fields defined on the type | |
98 | - Merge in overrides and build up all the fields | |
99 | ||
100 | :param SQLAlchemyObjectType obj_type: | |
101 | :param model: the SQLAlchemy model | |
102 | :param Registry registry: | |
103 | :param tuple[string] only_fields: | |
104 | :param tuple[string] exclude_fields: | |
105 | :param bool batching: | |
106 | :param function|None connection_field_factory: | |
107 | :rtype: OrderedDict[str, graphene.Field] | |
108 | """ | |
109 | inspected_model = sqlalchemy.inspect(model) | |
110 | # Gather all the relevant attributes from the SQLAlchemy model in order | |
111 | all_model_attrs = OrderedDict( | |
112 | inspected_model.column_attrs.items() + | |
113 | inspected_model.composites.items() + | |
114 | [(name, item) for name, item in inspected_model.all_orm_descriptors.items() | |
115 | if isinstance(item, hybrid_property)] + | |
116 | inspected_model.relationships.items() | |
117 | ) | |
118 | ||
119 | # Filter out excluded fields | |
120 | auto_orm_field_names = [] | |
121 | for attr_name, attr in all_model_attrs.items(): | |
122 | if (only_fields and attr_name not in only_fields) or (attr_name in exclude_fields): | |
123 | continue | |
124 | auto_orm_field_names.append(attr_name) | |
125 | ||
126 | # Gather all the ORM fields defined on the type | |
127 | custom_orm_fields_items = [ | |
128 | (attn_name, attr) | |
129 | for base in reversed(obj_type.__mro__) | |
130 | for attn_name, attr in base.__dict__.items() | |
131 | if isinstance(attr, ORMField) | |
132 | ] | |
133 | custom_orm_fields_items = sorted(custom_orm_fields_items, key=lambda item: item[1]) | |
134 | ||
135 | # Set the model_attr if not set | |
136 | for orm_field_name, orm_field in custom_orm_fields_items: | |
137 | attr_name = orm_field.kwargs.get('model_attr', orm_field_name) | |
138 | if attr_name not in all_model_attrs: | |
139 | raise ValueError(( | |
140 | "Cannot map ORMField to a model attribute.\n" | |
141 | "Field: '{}.{}'" | |
142 | ).format(obj_type.__name__, orm_field_name,)) | |
143 | orm_field.kwargs['model_attr'] = attr_name | |
144 | ||
145 | # Merge automatic fields with custom ORM fields | |
146 | orm_fields = OrderedDict(custom_orm_fields_items) | |
147 | for orm_field_name in auto_orm_field_names: | |
148 | if orm_field_name in orm_fields: | |
149 | continue | |
150 | orm_fields[orm_field_name] = ORMField(model_attr=orm_field_name) | |
151 | ||
152 | # Build all the field dictionary | |
24 | 153 | fields = OrderedDict() |
25 | ||
26 | for name, column in inspected_model.columns.items(): | |
27 | is_not_in_only = only_fields and name not in only_fields | |
28 | # is_already_created = name in options.fields | |
29 | is_excluded = name in exclude_fields # or is_already_created | |
30 | if is_not_in_only or is_excluded: | |
31 | # We skip this field if we specify only_fields and is not | |
32 | # in there. Or when we exclude this field in exclude_fields | |
33 | continue | |
34 | converted_column = convert_sqlalchemy_column(column, registry) | |
35 | fields[name] = converted_column | |
36 | ||
37 | for name, composite in inspected_model.composites.items(): | |
38 | is_not_in_only = only_fields and name not in only_fields | |
39 | # is_already_created = name in options.fields | |
40 | is_excluded = name in exclude_fields # or is_already_created | |
41 | if is_not_in_only or is_excluded: | |
42 | # We skip this field if we specify only_fields and is not | |
43 | # in there. Or when we exclude this field in exclude_fields | |
44 | continue | |
45 | converted_composite = convert_sqlalchemy_composite(composite, registry) | |
46 | fields[name] = converted_composite | |
47 | ||
48 | for hybrid_item in inspected_model.all_orm_descriptors: | |
49 | ||
50 | if type(hybrid_item) == hybrid_property: | |
51 | name = hybrid_item.__name__ | |
52 | ||
53 | is_not_in_only = only_fields and name not in only_fields | |
54 | # is_already_created = name in options.fields | |
55 | is_excluded = name in exclude_fields # or is_already_created | |
56 | ||
57 | if is_not_in_only or is_excluded: | |
58 | # We skip this field if we specify only_fields and is not | |
59 | # in there. Or when we exclude this field in exclude_fields | |
60 | continue | |
61 | ||
62 | converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) | |
63 | fields[name] = converted_hybrid_property | |
64 | ||
65 | # Get all the columns for the relationships on the model | |
66 | for relationship in inspected_model.relationships: | |
67 | is_not_in_only = only_fields and relationship.key not in only_fields | |
68 | # is_already_created = relationship.key in options.fields | |
69 | is_excluded = relationship.key in exclude_fields # or is_already_created | |
70 | if is_not_in_only or is_excluded: | |
71 | # We skip this field if we specify only_fields and is not | |
72 | # in there. Or when we exclude this field in exclude_fields | |
73 | continue | |
74 | converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory) | |
75 | name = relationship.key | |
76 | fields[name] = converted_relationship | |
154 | for orm_field_name, orm_field in orm_fields.items(): | |
155 | attr_name = orm_field.kwargs.pop('model_attr') | |
156 | attr = all_model_attrs[attr_name] | |
157 | resolver = get_custom_resolver(obj_type, orm_field_name) or get_attr_resolver(obj_type, attr_name) | |
158 | ||
159 | if isinstance(attr, ColumnProperty): | |
160 | field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs) | |
161 | elif isinstance(attr, RelationshipProperty): | |
162 | batching_ = orm_field.kwargs.pop('batching', batching) | |
163 | field = convert_sqlalchemy_relationship( | |
164 | attr, obj_type, connection_field_factory, batching_, orm_field_name, **orm_field.kwargs) | |
165 | elif isinstance(attr, CompositeProperty): | |
166 | if attr_name != orm_field_name or orm_field.kwargs: | |
167 | # TODO Add a way to override composite property fields | |
168 | raise ValueError( | |
169 | "ORMField kwargs for composite fields must be empty. " | |
170 | "Field: {}.{}".format(obj_type.__name__, orm_field_name)) | |
171 | field = convert_sqlalchemy_composite(attr, registry, resolver) | |
172 | elif isinstance(attr, hybrid_property): | |
173 | field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs) | |
174 | else: | |
175 | raise Exception('Property type is not supported') # Should never happen | |
176 | ||
177 | registry.register_orm_field(obj_type, orm_field_name, attr) | |
178 | fields[orm_field_name] = field | |
77 | 179 | |
78 | 180 | return fields |
79 | 181 | |
99 | 201 | use_connection=None, |
100 | 202 | interfaces=(), |
101 | 203 | id=None, |
102 | connection_field_factory=default_connection_field_factory, | |
204 | batching=False, | |
205 | connection_field_factory=None, | |
103 | 206 | _meta=None, |
104 | 207 | **options |
105 | 208 | ): |
115 | 218 | 'Registry, received "{}".' |
116 | 219 | ).format(cls.__name__, registry) |
117 | 220 | |
221 | if only_fields and exclude_fields: | |
222 | raise ValueError("The options 'only_fields' and 'exclude_fields' cannot be both set on the same type.") | |
223 | ||
118 | 224 | sqla_fields = yank_fields_from_attrs( |
119 | 225 | construct_fields( |
226 | obj_type=cls, | |
120 | 227 | model=model, |
121 | 228 | registry=registry, |
122 | 229 | only_fields=only_fields, |
123 | 230 | exclude_fields=exclude_fields, |
124 | connection_field_factory=connection_field_factory | |
231 | batching=batching, | |
232 | connection_field_factory=connection_field_factory, | |
125 | 233 | ), |
126 | _as=Field | |
234 | _as=Field, | |
235 | sort=False, | |
127 | 236 | ) |
128 | 237 | |
129 | 238 | if use_connection is None and interfaces: |
158 | 267 | |
159 | 268 | _meta.connection = connection |
160 | 269 | _meta.id = id or "id" |
270 | ||
271 | cls.connection = connection # Public way to get the connection | |
161 | 272 | |
162 | 273 | super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( |
163 | 274 | _meta=_meta, interfaces=interfaces, **options |
190 | 301 | # graphene_type = info.parent_type.graphene_type |
191 | 302 | keys = self.__mapper__.primary_key_from_instance(self) |
192 | 303 | return tuple(keys) if len(keys) > 1 else keys[0] |
304 | ||
305 | @classmethod | |
306 | def enum_for_field(cls, field_name): | |
307 | return enum_for_field(cls, field_name) | |
308 | ||
309 | sort_enum = classmethod(sort_enum_for_object_type) | |
310 | ||
311 | sort_argument = classmethod(sort_argument_for_object_type) |
0 | import re | |
1 | import warnings | |
2 | ||
0 | 3 | from sqlalchemy.exc import ArgumentError |
1 | from sqlalchemy.inspection import inspect | |
2 | 4 | from sqlalchemy.orm import class_mapper, object_mapper |
3 | 5 | from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError |
4 | ||
5 | from graphene import Argument, Enum, List | |
6 | 6 | |
7 | 7 | |
8 | 8 | def get_session(context): |
40 | 40 | return True |
41 | 41 | |
42 | 42 | |
43 | def _symbol_name(column_name, is_asc): | |
44 | return column_name + ("_asc" if is_asc else "_desc") | |
43 | def to_type_name(name): | |
44 | """Convert the given name to a GraphQL type name.""" | |
45 | return "".join(part[:1].upper() + part[1:] for part in name.split("_")) | |
46 | ||
47 | ||
48 | _re_enum_value_name_1 = re.compile("(.)([A-Z][a-z]+)") | |
49 | _re_enum_value_name_2 = re.compile("([a-z0-9])([A-Z])") | |
50 | ||
51 | ||
52 | def to_enum_value_name(name): | |
53 | """Convert the given name to a GraphQL enum value name.""" | |
54 | return _re_enum_value_name_2.sub( | |
55 | r"\1_\2", _re_enum_value_name_1.sub(r"\1_\2", name) | |
56 | ).upper() | |
45 | 57 | |
46 | 58 | |
47 | 59 | class EnumValue(str): |
48 | """Subclass of str that stores a string and an arbitrary value in the "value" property""" | |
60 | """String that has an additional value attached. | |
49 | 61 | |
50 | def __new__(cls, str_value, value): | |
51 | return super(EnumValue, cls).__new__(cls, str_value) | |
62 | This is used to attach SQLAlchemy model columns to Enum symbols. | |
63 | """ | |
52 | 64 | |
53 | def __init__(self, str_value, value): | |
65 | def __new__(cls, s, value): | |
66 | return super(EnumValue, cls).__new__(cls, s) | |
67 | ||
68 | def __init__(self, _s, value): | |
54 | 69 | super(EnumValue, self).__init__() |
55 | 70 | self.value = value |
56 | 71 | |
57 | 72 | |
58 | # Cache for the generated enums, to avoid name clash | |
59 | _ENUM_CACHE = {} | |
73 | def _deprecated_default_symbol_name(column_name, sort_asc): | |
74 | return column_name + ("_asc" if sort_asc else "_desc") | |
60 | 75 | |
61 | 76 | |
62 | def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): | |
63 | name = name or cls.__name__ + "SortEnum" | |
64 | if name in _ENUM_CACHE: | |
65 | return _ENUM_CACHE[name] | |
66 | items = [] | |
67 | default = [] | |
68 | for column in inspect(cls).columns.values(): | |
69 | asc_name = symbol_name(column.name, True) | |
70 | asc_value = EnumValue(asc_name, column.asc()) | |
71 | desc_name = symbol_name(column.name, False) | |
72 | desc_value = EnumValue(desc_name, column.desc()) | |
73 | if column.primary_key: | |
74 | default.append(asc_value) | |
75 | items.extend(((asc_name, asc_value), (desc_name, desc_value))) | |
76 | enum = Enum(name, items) | |
77 | _ENUM_CACHE[name] = (enum, default) | |
78 | return enum, default | |
77 | # unfortunately, we cannot use lru_cache because we still support Python 2 | |
78 | _deprecated_object_type_cache = {} | |
79 | 79 | |
80 | 80 | |
81 | def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): | |
82 | """Create Graphene Enum for sorting a SQLAlchemy class query | |
81 | def _deprecated_object_type_for_model(cls, name): | |
83 | 82 | |
84 | Parameters | |
85 | - cls : Sqlalchemy model class | |
86 | Model used to create the sort enumerator | |
87 | - name : str, optional, default None | |
88 | Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'` | |
89 | - symbol_name : function, optional, default `_symbol_name` | |
90 | Function which takes the column name and a boolean indicating if the sort direction is ascending, | |
91 | and returns the symbol name for the current column and sort direction. | |
92 | The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc' | |
83 | try: | |
84 | return _deprecated_object_type_cache[cls, name] | |
85 | except KeyError: | |
86 | from .types import SQLAlchemyObjectType | |
93 | 87 | |
94 | Returns | |
95 | - Enum | |
96 | The Graphene enumerator | |
88 | obj_type_name = name or cls.__name__ | |
89 | ||
90 | class ObjType(SQLAlchemyObjectType): | |
91 | class Meta: | |
92 | name = obj_type_name | |
93 | model = cls | |
94 | ||
95 | _deprecated_object_type_cache[cls, name] = ObjType | |
96 | return ObjType | |
97 | ||
98 | ||
99 | def sort_enum_for_model(cls, name=None, symbol_name=None): | |
100 | """Get a Graphene Enum for sorting the given model class. | |
101 | ||
102 | This is deprecated, please use object_type.sort_enum() instead. | |
97 | 103 | """ |
98 | enum, _ = _sort_enum_for_model(cls, name, symbol_name) | |
99 | return enum | |
104 | warnings.warn( | |
105 | "sort_enum_for_model() is deprecated; use object_type.sort_enum() instead.", | |
106 | DeprecationWarning, | |
107 | stacklevel=2, | |
108 | ) | |
109 | ||
110 | from .enums import sort_enum_for_object_type | |
111 | ||
112 | return sort_enum_for_object_type( | |
113 | _deprecated_object_type_for_model(cls, name), | |
114 | name, | |
115 | get_symbol_name=symbol_name or _deprecated_default_symbol_name, | |
116 | ) | |
100 | 117 | |
101 | 118 | |
102 | 119 | def sort_argument_for_model(cls, has_default=True): |
103 | """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model. | |
104 | If `has_default` is True (the default) it will sort the result by the primary key(s) | |
120 | """Get a Graphene Argument for sorting the given model class. | |
121 | ||
122 | This is deprecated, please use object_type.sort_argument() instead. | |
105 | 123 | """ |
106 | enum, default = _sort_enum_for_model(cls) | |
124 | warnings.warn( | |
125 | "sort_argument_for_model() is deprecated;" | |
126 | " use object_type.sort_argument() instead.", | |
127 | DeprecationWarning, | |
128 | stacklevel=2, | |
129 | ) | |
130 | ||
131 | from graphene import Argument, List | |
132 | from .enums import sort_enum_for_object_type | |
133 | ||
134 | enum = sort_enum_for_object_type( | |
135 | _deprecated_object_type_for_model(cls, None), | |
136 | get_symbol_name=_deprecated_default_symbol_name, | |
137 | ) | |
107 | 138 | if not has_default: |
108 | default = None | |
109 | return Argument(List(enum), default_value=default) | |
139 | enum.default = None | |
140 | ||
141 | return Argument(List(enum), default_value=enum.default) |
5 | 5 | max-line-length = 120 |
6 | 6 | |
7 | 7 | [isort] |
8 | no_lines_before=FIRSTPARTY | |
8 | 9 | known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme |
9 | 10 | known_first_party=graphene_sqlalchemy |
10 | known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils | |
11 | known_third_party=app,database,flask,graphql,mock,models,nameko,pkg_resources,promise,pytest,schema,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils | |
11 | 12 | sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER |
12 | no_lines_before=FIRSTPARTY | |
13 | skip_glob=examples/nameko_sqlalchemy | |
13 | 14 | |
14 | 15 | [bdist_wheel] |
15 | 16 | universal=1 |
13 | 13 | requirements = [ |
14 | 14 | # To keep things simple, we only support newer versions of Graphene |
15 | 15 | "graphene>=2.1.3,<3", |
16 | "promise>=2.3", | |
16 | 17 | # Tests fail with 1.0.19 |
17 | "SQLAlchemy>=1.1,<2", | |
18 | "SQLAlchemy>=1.2,<2", | |
18 | 19 | "six>=1.10.0,<2", |
19 | 20 | "singledispatch>=3.4.0.3,<4", |
20 | 21 | ] |
28 | 29 | "mock==2.0.0", |
29 | 30 | "pytest-cov==2.6.1", |
30 | 31 | "sqlalchemy_utils==0.33.9", |
32 | "pytest-benchmark==3.2.1", | |
31 | 33 | ] |
32 | 34 | |
33 | 35 | setup( |
46 | 48 | "Programming Language :: Python :: 2", |
47 | 49 | "Programming Language :: Python :: 2.7", |
48 | 50 | "Programming Language :: Python :: 3", |
49 | "Programming Language :: Python :: 3.3", | |
50 | "Programming Language :: Python :: 3.4", | |
51 | 51 | "Programming Language :: Python :: 3.5", |
52 | 52 | "Programming Language :: Python :: 3.6", |
53 | 53 | "Programming Language :: Python :: 3.7", |
59 | 59 | extras_require={ |
60 | 60 | "dev": [ |
61 | 61 | "tox==3.7.0", # Should be kept in sync with tox.ini |
62 | "coveralls==1.7.0", | |
62 | "coveralls==1.10.0", | |
63 | 63 | "pre-commit==1.14.4", |
64 | 64 | ], |
65 | 65 | "test": tests_require, |