diff --git a/.gitignore b/.gitignore index 2c4ca2b..e4070f3 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ nosetests.xml coverage.xml *,cover +.pytest_cache/ # Translations *.mo diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..136f8e7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +default_language_version: + python: python3.7 +repos: +- repo: git://github.com/pre-commit/pre-commit-hooks + rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 + hooks: + - id: check-merge-conflict + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + exclude: ^docs/.*$ + - id: trailing-whitespace + exclude: README.md +- repo: git://github.com/PyCQA/flake8 + rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 + hooks: + - id: flake8 +- repo: git://github.com/asottile/seed-isort-config + rev: v1.7.0 + hooks: + - id: seed-isort-config +- repo: git://github.com/pre-commit/mirrors-isort + rev: v4.3.4 + hooks: + - id: isort diff --git a/.travis.yml b/.travis.yml index dd80108..39151a5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,42 +1,45 @@ language: python -sudo: false -python: -- 2.7 -- 3.4 -- 3.5 -- 3.6 -before_install: -install: -- | - if [ "$TEST_TYPE" = build ]; then - pip install pytest==3.0.2 pytest-cov pytest-benchmark coveralls six mock sqlalchemy_utils - pip install -e . - python setup.py develop - elif [ "$TEST_TYPE" = lint ]; then - pip install flake8 - fi -script: -- | - if [ "$TEST_TYPE" = lint ]; then - echo "Checking Python code lint." - flake8 graphene_sqlalchemy - exit - elif [ "$TEST_TYPE" = build ]; then - py.test --cov=graphene_sqlalchemy graphene_sqlalchemy examples - fi -after_success: -- | - if [ "$TEST_TYPE" = build ]; then - coveralls - fi -env: - matrix: - - TEST_TYPE=build matrix: - fast_finish: true include: - - python: '2.7' - env: TEST_TYPE=lint + # Python 2.7 + - env: TOXENV=py27 + python: 2.7 + # Python 3.5 + - env: TOXENV=py34 + python: 3.4 + # Python 3.5 + - env: TOXENV=py35 + python: 3.5 + # Python 3.6 + - env: TOXENV=py36 + python: 3.6 + # Python 3.7 + - env: TOXENV=py37 + python: 3.7 + dist: xenial + # SQLAlchemy 1.1 + - env: TOXENV=py37-sql11 + python: 3.7 + dist: xenial + # SQLAlchemy 1.2 + - env: TOXENV=py37-sql12 + python: 3.7 + dist: xenial + # SQLAlchemy 1.3 + - env: TOXENV=py37-sql13 + python: 3.7 + dist: xenial + # Pre-commit + - env: TOXENV=pre-commit + python: 3.7 + dist: xenial +install: pip install .[dev] +script: tox +after_success: coveralls +cache: + directories: + - $HOME/.cache/pip + - $HOME/.cache/pre-commit deploy: provider: pypi user: syrusakbary diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..59ca64d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +## Local Development + +Set up our development dependencies: + +```sh +pip install -e ".[dev]" +pre-commit install +``` + +We use `tox` to test this library against different versions of `python` and `SQLAlchemy`. +While developping locally, it is usually fine to run the tests against the most recent versions: + +```sh +tox -e py37 # Python 3.7, SQLAlchemy < 2.0 +tox -e py37 -- -v -s # Verbose output +tox -e py37 -- -k test_query # Only test_query.py +``` + +Our linters will run automatically when committing via git hooks but you can also run them manually: + +```sh +tox -e pre-commit +``` + +## Release Process + +1. Update the version number in graphene_sqlalchemy/__init__.py via a PR. + +2. Once the PR is merged, tag the commit on master with the new version (only maintainers of the repo can do this). For example, "v2.1.2". Travis will then automatically build this tag and release it to Pypi. + +3. Make sure to create a new release on github (via the release tab) that lists all the changes that went into the new version. diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..a3c843c --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 Syrus Akbary + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 0ca6ac5..2ba0d1c 100644 --- a/README.md +++ b/README.md @@ -22,14 +22,13 @@ ```python from sqlalchemy import Column, Integer, String -from sqlalchemy.orm import relationship from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class UserModel(Base): - __tablename__ = 'department' + __tablename__ = 'user' id = Column(Integer, primary_key=True) name = Column(String) last_name = Column(String) @@ -38,11 +37,16 @@ To create a GraphQL schema for it you simply have to write the following: ```python +import graphene from graphene_sqlalchemy import SQLAlchemyObjectType class User(SQLAlchemyObjectType): class Meta: model = UserModel + # only return specified fields + only_fields = ("name",) + # exclude specified fields + exclude_fields = ("last_name",) class Query(graphene.ObjectType): users = graphene.List(User) @@ -98,21 +102,13 @@ schema = graphene.Schema(query=Query) ``` +### Full Examples + To learn more check out the following [examples](examples/): -* **Full example**: [Flask SQLAlchemy example](examples/flask_sqlalchemy) - +- [Flask SQLAlchemy example](examples/flask_sqlalchemy) +- [Nameko SQLAlchemy example](examples/nameko_sqlalchemy) ## Contributing -After cloning this repo, ensure dependencies are installed by running: - -```sh -python setup.py install -``` - -After developing, the full test suite can be evaluated by running: - -```sh -python setup.py test # Use --pytest-args="-v -s" for verbose mode -``` +See [CONTRIBUTING.md](/CONTRIBUTING.md) diff --git a/docs/conf.py b/docs/conf.py index d729246..3fa6391 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -136,7 +136,7 @@ # html_theme = 'alabaster' # if on_rtd: # html_theme = 'sphinx_rtd_theme' -import sphinx_graphene_theme +import sphinx_graphene_theme # isort:skip html_theme = "sphinx_graphene_theme" diff --git a/docs/examples.rst b/docs/examples.rst new file mode 100644 index 0000000..283a0f5 --- /dev/null +++ b/docs/examples.rst @@ -0,0 +1,82 @@ +Schema Examples +=========================== + + +Search all Models with Union +---------------------------- + +.. code:: python + + class Book(SQLAlchemyObjectType): + class Meta: + model = BookModel + interfaces = (relay.Node,) + + + class BookConnection(relay.Connection): + class Meta: + node = Book + + + class Author(SQLAlchemyObjectType): + class Meta: + model = AuthorModel + interfaces = (relay.Node,) + + + class AuthorConnection(relay.Connection): + class Meta: + node = Author + + + class SearchResult(graphene.Union): + class Meta: + types = (Book, Author) + + + class Query(graphene.ObjectType): + node = relay.Node.Field() + search = graphene.List(SearchResult, q=graphene.String()) # List field for search results + + # Normal Fields + all_books = SQLAlchemyConnectionField(BookConnection) + all_authors = SQLAlchemyConnectionField(AuthorConnection) + + def resolve_search(self, info, **args): + q = args.get("q") # Search query + + # Get queries + bookdata_query = BookData.get_query(info) + author_query = Author.get_query(info) + + # Query Books + books = bookdata_query.filter((BookModel.title.contains(q)) | + (BookModel.isbn.contains(q)) | + (BookModel.authors.any(AuthorModel.name.contains(q)))).all() + + # Query Authors + authors = author_query.filter(AuthorModel.name.contains(q)).all() + + return authors + books # Combine lists + + schema = graphene.Schema(query=Query, types=[Book, Author, SearchResult]) + +Example GraphQL query + +.. code:: + + book(id: "Qm9vazow") { + id + title + } + search(q: "Making Games") { + __typename + ... on Author { + fname + lname + } + ... on Book { + title + isbn + } + } diff --git a/docs/index.rst b/docs/index.rst index 1bc2234..81b2f31 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,3 +8,4 @@ tutorial tips + examples diff --git a/docs/requirements.txt b/docs/requirements.txt index 5de8cc6..666a8c9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,2 @@ # Docs template -https://github.com/graphql-python/graphene-python.org/archive/docs.zip +http://graphene-python.org/sphinx_graphene_theme.zip diff --git a/docs/tips.rst b/docs/tips.rst index 9c4a98b..1fd3910 100644 --- a/docs/tips.rst +++ b/docs/tips.rst @@ -1,14 +1,11 @@ ==== -Tips -==== - Tips ==== Querying -------- -For make querying to the database work, there are two alternatives: +In order to make querying against the database work, there are two alternatives: - Set the db session when you do the execution: @@ -30,3 +27,61 @@ If you don't specify any, the following error will be displayed: ``A query in the model Base or a session in the schema is required for querying.`` + +Sorting +------- + +By default the SQLAlchemyConnectionField sorts the result elements over the primary key(s). +The query has a `sort` argument which allows to sort over a different column(s) + +Given the model + +.. code:: python + + class Pet(Base): + __tablename__ = 'pets' + id = Column(Integer(), primary_key=True) + name = Column(String(30)) + pet_kind = Column(Enum('cat', 'dog', name='pet_kind'), nullable=False) + + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + + + class PetConnection(Connection): + class Meta: + node = PetNode + + + class Query(ObjectType): + allPets = SQLAlchemyConnectionField(PetConnection) + +some of the allowed queries are + +- Sort in ascending order over the `name` column + +.. code:: + + allPets(sort: name_asc){ + edges { + node { + name + } + } + } + +- Sort in descending order over the `per_kind` column and in ascending order over the `name` column + +.. code:: + + allPets(sort: [pet_kind_desc, name_asc]) { + edges { + node { + name + petKind + } + } + } + diff --git a/docs/tutorial.rst b/docs/tutorial.rst index b07eaec..bc5ee62 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -93,7 +93,7 @@ import graphene from graphene import relay from graphene_sqlalchemy import SQLAlchemyObjectType, SQLAlchemyConnectionField - from models import db_session, Department as DepartmentModel, Employee as EmployeeModel + from .models import db_session, Department as DepartmentModel, Employee as EmployeeModel class Department(SQLAlchemyObjectType): @@ -102,15 +102,28 @@ interfaces = (relay.Node, ) + class DepartmentConnection(relay.Connection): + class Meta: + node = Department + + class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel interfaces = (relay.Node, ) + class EmployeeConnection(relay.Connection): + class Meta: + node = Employee + + class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) + # Allows sorting over multiple columns, by default over the primary key + all_employees = SQLAlchemyConnectionField(EmployeeConnection) + # Disable sorting over this field + all_departments = SQLAlchemyConnectionField(DepartmentConnection, sort=None) schema = graphene.Schema(query=Query) @@ -133,8 +146,8 @@ from flask import Flask from flask_graphql import GraphQLView - from models import db_session - from schema import schema, Department + from .models import db_session + from .schema import schema, Department app = Flask(__name__) app.debug = True @@ -162,7 +175,7 @@ $ python - >>> from models import engine, db_session, Base, Department, Employee + >>> from .models import engine, db_session, Base, Department, Employee >>> Base.metadata.create_all(bind=engine) >>> # Fill the tables with some data diff --git a/examples/flask_sqlalchemy/app.py b/examples/flask_sqlalchemy/app.py index 64390aa..a4d3f29 100755 --- a/examples/flask_sqlalchemy/app.py +++ b/examples/flask_sqlalchemy/app.py @@ -2,9 +2,10 @@ from flask import Flask -from database import db_session, init_db from flask_graphql import GraphQLView -from schema import schema + +from .database import db_session, init_db +from .schema import schema app = Flask(__name__) app.debug = True diff --git a/examples/flask_sqlalchemy/database.py b/examples/flask_sqlalchemy/database.py index ca4d412..01e76ca 100644 --- a/examples/flask_sqlalchemy/database.py +++ b/examples/flask_sqlalchemy/database.py @@ -14,7 +14,7 @@ # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - from models import Department, Employee, Role + from .models import Department, Employee, Role Base.metadata.drop_all(bind=engine) Base.metadata.create_all(bind=engine) diff --git a/examples/flask_sqlalchemy/models.py b/examples/flask_sqlalchemy/models.py index 119aca0..e164c01 100644 --- a/examples/flask_sqlalchemy/models.py +++ b/examples/flask_sqlalchemy/models.py @@ -1,7 +1,7 @@ from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func from sqlalchemy.orm import backref, relationship -from database import Base +from .database import Base class Department(Base): diff --git a/examples/flask_sqlalchemy/requirements.txt b/examples/flask_sqlalchemy/requirements.txt index f5bcaa8..337ff60 100644 --- a/examples/flask_sqlalchemy/requirements.txt +++ b/examples/flask_sqlalchemy/requirements.txt @@ -1,4 +1,4 @@ graphene[sqlalchemy] SQLAlchemy==1.0.11 -Flask==0.10.1 +Flask==0.12.4 Flask-GraphQL==1.3.0 diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 5d9d3b7..cbee081 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -1,37 +1,47 @@ import graphene from graphene import relay -from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType -from models import Department as DepartmentModel -from models import Employee as EmployeeModel -from models import Role as RoleModel +from graphene_sqlalchemy import (SQLAlchemyConnectionField, + SQLAlchemyObjectType, utils) + +from .models import Department as DepartmentModel +from .models import Employee as EmployeeModel +from .models import Role as RoleModel class Department(SQLAlchemyObjectType): - class Meta: model = DepartmentModel interfaces = (relay.Node, ) class Employee(SQLAlchemyObjectType): - class Meta: model = EmployeeModel interfaces = (relay.Node, ) class Role(SQLAlchemyObjectType): - class Meta: model = RoleModel interfaces = (relay.Node, ) +SortEnumEmployee = utils.sort_enum_for_model(EmployeeModel, 'SortEnumEmployee', + lambda c, d: c.upper() + ('_ASC' if d else '_DESC')) + + class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) + # Allow only single column sorting + all_employees = SQLAlchemyConnectionField( + Employee, + sort=graphene.Argument( + SortEnumEmployee, + default_value=utils.EnumValue('id_asc', EmployeeModel.id.asc()))) + # Allows sorting over multiple columns, by default over the primary key all_roles = SQLAlchemyConnectionField(Role) - role = graphene.Field(Role) + # Disable sorting over this field + all_departments = SQLAlchemyConnectionField(Department, sort=None) schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) diff --git a/examples/nameko_sqlalchemy/README.md b/examples/nameko_sqlalchemy/README.md new file mode 100644 index 0000000..39cfe92 --- /dev/null +++ b/examples/nameko_sqlalchemy/README.md @@ -0,0 +1,54 @@ +Example Nameko+Graphene-SQLAlchemy Project +================================ + +This example is for those who are not using frameworks like Flask | Django which already have a View wrapper implemented to handle graphql request and response accordingly + +If you need a [graphiql](https://github.com/graphql/graphiql) interface on your application, kindly look at [flask_sqlalchemy](../flask_sqlalchemy). + +Using [nameko](https://github.com/nameko/nameko) as an example, but you can get rid of `service.py` + +The project contains two models, one named `Department` and another +named `Employee`. + +Getting started +--------------- + +First you'll need to get the source of the project. Do this by cloning the +whole Graphene repository: + +```bash +# Get the example project code +git clone https://github.com/graphql-python/graphene-sqlalchemy.git +cd graphene-sqlalchemy/examples/nameko_sqlalchemy +``` + +It is good idea (but not required) to create a virtual environment +for this project. We'll do this using +[virtualenv](http://docs.python-guide.org/en/latest/dev/virtualenvs/) +to keep things simple, +but you may also find something like +[virtualenvwrapper](https://virtualenvwrapper.readthedocs.org/en/latest/) +to be useful: + +```bash +# Create a virtualenv in which we can install the dependencies +virtualenv env +source env/bin/activate +``` + +Now we can install our dependencies: + +```bash +pip install -r requirements.txt +``` + +Now the following command will setup the database, and start the server: + +```bash +./run.sh + +``` + +Now head on over to postman and send POST request to: +[http://127.0.0.1:5000/graphql](http://127.0.0.1:5000/graphql) +and run some queries! diff --git a/examples/nameko_sqlalchemy/__init__.py b/examples/nameko_sqlalchemy/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/nameko_sqlalchemy/app.py b/examples/nameko_sqlalchemy/app.py new file mode 100755 index 0000000..42a40a0 --- /dev/null +++ b/examples/nameko_sqlalchemy/app.py @@ -0,0 +1,37 @@ +from graphql_server import (HttpQueryError, default_format_error, + encode_execution_results, json_encode, + load_json_body, run_http_query) + +from .database import db_session, init_db +from .schema import schema + + +class App(): + def __init__(self): + init_db() + + def query(self, request): + data = self.parse_body(request) + execution_results, params = run_http_query( + schema, + 'post', + data) + result, status_code = encode_execution_results( + execution_results, + format_error=default_format_error,is_batch=False, encode=json_encode) + return result + + def parse_body(self,request): + # We use mimetype here since we don't need the other + # information provided by content_type + content_type = request.mimetype + if content_type == 'application/graphql': + return {'query': request.data.decode('utf8')} + + elif content_type == 'application/json': + return load_json_body(request.data.decode('utf8')) + + elif content_type in ('application/x-www-form-urlencoded', 'multipart/form-data'): + return request.form + + return {} diff --git a/examples/nameko_sqlalchemy/config.yml b/examples/nameko_sqlalchemy/config.yml new file mode 100644 index 0000000..8ca6a45 --- /dev/null +++ b/examples/nameko_sqlalchemy/config.yml @@ -0,0 +1 @@ +WEB_SERVER_ADDRESS: '0.0.0.0:5000' diff --git a/examples/nameko_sqlalchemy/database.py b/examples/nameko_sqlalchemy/database.py new file mode 100644 index 0000000..01e76ca --- /dev/null +++ b/examples/nameko_sqlalchemy/database.py @@ -0,0 +1,38 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import scoped_session, sessionmaker + +engine = create_engine('sqlite:///database.sqlite3', convert_unicode=True) +db_session = scoped_session(sessionmaker(autocommit=False, + autoflush=False, + bind=engine)) +Base = declarative_base() +Base.query = db_session.query_property() + + +def init_db(): + # import all modules here that might define models so that + # they will be registered properly on the metadata. Otherwise + # you will have to import them first before calling init_db() + from .models import Department, Employee, Role + Base.metadata.drop_all(bind=engine) + Base.metadata.create_all(bind=engine) + + # Create the fixtures + engineering = Department(name='Engineering') + db_session.add(engineering) + hr = Department(name='Human Resources') + db_session.add(hr) + + manager = Role(name='manager') + db_session.add(manager) + engineer = Role(name='engineer') + db_session.add(engineer) + + peter = Employee(name='Peter', department=engineering, role=engineer) + db_session.add(peter) + roy = Employee(name='Roy', department=engineering, role=engineer) + db_session.add(roy) + tracy = Employee(name='Tracy', department=hr, role=manager) + db_session.add(tracy) + db_session.commit() diff --git a/examples/nameko_sqlalchemy/models.py b/examples/nameko_sqlalchemy/models.py new file mode 100644 index 0000000..e164c01 --- /dev/null +++ b/examples/nameko_sqlalchemy/models.py @@ -0,0 +1,39 @@ +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, func +from sqlalchemy.orm import backref, relationship + +from .database import Base + + +class Department(Base): + __tablename__ = 'department' + id = Column(Integer, primary_key=True) + name = Column(String) + + +class Role(Base): + __tablename__ = 'roles' + role_id = Column(Integer, primary_key=True) + name = Column(String) + + +class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + name = Column(String) + # Use default=func.now() to set the default hiring time + # of an Employee to be the current time when an + # Employee record was created + hired_on = Column(DateTime, default=func.now()) + department_id = Column(Integer, ForeignKey('department.id')) + role_id = Column(Integer, ForeignKey('roles.role_id')) + # Use cascade='delete,all' to propagate the deletion of a Department onto its Employees + department = relationship( + Department, + backref=backref('employees', + uselist=True, + cascade='delete,all')) + role = relationship( + Role, + backref=backref('roles', + uselist=True, + cascade='delete,all')) diff --git a/examples/nameko_sqlalchemy/requirements.txt b/examples/nameko_sqlalchemy/requirements.txt new file mode 100644 index 0000000..be037f7 --- /dev/null +++ b/examples/nameko_sqlalchemy/requirements.txt @@ -0,0 +1,4 @@ +graphene[sqlalchemy] +SQLAlchemy==1.0.11 +nameko +graphql-server-core diff --git a/examples/nameko_sqlalchemy/run.sh b/examples/nameko_sqlalchemy/run.sh new file mode 100755 index 0000000..bfe17d9 --- /dev/null +++ b/examples/nameko_sqlalchemy/run.sh @@ -0,0 +1,4 @@ +#!/bin/sh +echo "Starting application service server" +# Run Service +nameko run --config config.yml service diff --git a/examples/nameko_sqlalchemy/schema.py b/examples/nameko_sqlalchemy/schema.py new file mode 100644 index 0000000..fa74735 --- /dev/null +++ b/examples/nameko_sqlalchemy/schema.py @@ -0,0 +1,38 @@ +import graphene +from graphene import relay +from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType + +from .models import Department as DepartmentModel +from .models import Employee as EmployeeModel +from .models import Role as RoleModel + + +class Department(SQLAlchemyObjectType): + + class Meta: + model = DepartmentModel + interfaces = (relay.Node, ) + + +class Employee(SQLAlchemyObjectType): + + class Meta: + model = EmployeeModel + interfaces = (relay.Node, ) + + +class Role(SQLAlchemyObjectType): + + class Meta: + model = RoleModel + interfaces = (relay.Node, ) + + +class Query(graphene.ObjectType): + node = relay.Node.Field() + all_employees = SQLAlchemyConnectionField(Employee) + all_roles = SQLAlchemyConnectionField(Role) + role = graphene.Field(Role) + + +schema = graphene.Schema(query=Query, types=[Department, Employee, Role]) diff --git a/examples/nameko_sqlalchemy/service.py b/examples/nameko_sqlalchemy/service.py new file mode 100644 index 0000000..9815750 --- /dev/null +++ b/examples/nameko_sqlalchemy/service.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +from nameko.web.handlers import http + +from .app import App + + +class DepartmentService: + name = 'department' + + @http('POST', '/graphql') + def query(self, request): + return App().query(request) diff --git a/graphene_sqlalchemy/__init__.py b/graphene_sqlalchemy/__init__.py index 768b6d6..d328304 100644 --- a/graphene_sqlalchemy/__init__.py +++ b/graphene_sqlalchemy/__init__.py @@ -1,20 +1,13 @@ -from .types import ( - SQLAlchemyObjectType, -) -from .fields import ( - SQLAlchemyConnectionField -) -from .utils import ( - get_query, - get_session -) +from .types import SQLAlchemyObjectType +from .fields import SQLAlchemyConnectionField +from .utils import get_query, get_session -__version__ = '2.0.0' +__version__ = "2.1.2" __all__ = [ - '__version__', - 'SQLAlchemyObjectType', - 'SQLAlchemyConnectionField', - 'get_query', - 'get_session' + "__version__", + "SQLAlchemyObjectType", + "SQLAlchemyConnectionField", + "get_query", + "get_session", ] diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 2543fc8..7cc259e 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -7,24 +7,21 @@ String) from graphene.types.json import JSONString -from .fields import createConnectionField - try: - from sqlalchemy_utils import ( - ChoiceType, JSONType, ScalarListType, TSVectorType) + from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType except ImportError: ChoiceType = JSONType = ScalarListType = TSVectorType = object def get_column_doc(column): - return getattr(column, 'doc', None) + return getattr(column, "doc", None) def is_column_nullable(column): - return bool(getattr(column, 'nullable', True)) + return bool(getattr(column, "nullable", True)) -def convert_sqlalchemy_relationship(relationship, registry): +def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory): direction = relationship.direction model = relationship.mapper.entity @@ -36,15 +33,14 @@ return Field(_type) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): if _type._meta.connection: - return createConnectionField(_type) + return connection_field_factory(relationship, registry) return Field(List(_type)) return Dynamic(dynamic_type) def convert_sqlalchemy_hybrid_method(hybrid_item): - return String(description=getattr(hybrid_item, '__doc__', None), - required=False) + return String(description=getattr(hybrid_item, "__doc__", None), required=False) def convert_sqlalchemy_composite(composite, registry): @@ -52,23 +48,27 @@ if not converter: try: raise Exception( - "Don't know how to convert the composite field %s (%s)" % - (composite, composite.composite_class)) + "Don't know how to convert the composite field %s (%s)" + % (composite, composite.composite_class) + ) except AttributeError: # handle fields that are not attached to a class yet (don't have a parent) raise Exception( - "Don't know how to convert the composite field %r (%s)" % - (composite, composite.composite_class)) + "Don't know how to convert the composite field %r (%s)" + % (composite, composite.composite_class) + ) return converter(composite, registry) def _register_composite_class(cls, registry=None): if registry is None: from .registry import get_global_registry + registry = get_global_registry() def inner(fn): registry.register_composite_converter(cls, fn) + return inner @@ -76,13 +76,15 @@ def convert_sqlalchemy_column(column, registry=None): - return convert_sqlalchemy_type(getattr(column, 'type', None), column, registry) + return convert_sqlalchemy_type(getattr(column, "type", None), column, registry) @singledispatch def convert_sqlalchemy_type(type, column, registry=None): raise Exception( - "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) + "Don't know how to convert the SQLAlchemy field %s (%s)" + % (column, column.__class__) + ) @convert_sqlalchemy_type.register(types.Date) @@ -91,47 +93,74 @@ @convert_sqlalchemy_type.register(types.Text) @convert_sqlalchemy_type.register(types.Unicode) @convert_sqlalchemy_type.register(types.UnicodeText) -@convert_sqlalchemy_type.register(types.Enum) -@convert_sqlalchemy_type.register(postgresql.ENUM) @convert_sqlalchemy_type.register(postgresql.UUID) +@convert_sqlalchemy_type.register(postgresql.INET) +@convert_sqlalchemy_type.register(postgresql.CIDR) @convert_sqlalchemy_type.register(TSVectorType) def convert_column_to_string(type, column, registry=None): - return String(description=get_column_doc(column), - required=not(is_column_nullable(column))) + return String( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) @convert_sqlalchemy_type.register(types.DateTime) def convert_column_to_datetime(type, column, registry=None): from graphene.types.datetime import DateTime - return DateTime(description=get_column_doc(column), - required=not(is_column_nullable(column))) + + return DateTime( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) @convert_sqlalchemy_type.register(types.SmallInteger) @convert_sqlalchemy_type.register(types.Integer) def convert_column_to_int_or_id(type, column, registry=None): if column.primary_key: - return ID(description=get_column_doc(column), required=not (is_column_nullable(column))) + return ID( + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) else: - return Int(description=get_column_doc(column), - required=not (is_column_nullable(column))) + return Int( + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) @convert_sqlalchemy_type.register(types.Boolean) def convert_column_to_boolean(type, column, registry=None): - return Boolean(description=get_column_doc(column), required=not(is_column_nullable(column))) + return Boolean( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) @convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Numeric) @convert_sqlalchemy_type.register(types.BigInteger) def convert_column_to_float(type, column, registry=None): - return Float(description=get_column_doc(column), required=not(is_column_nullable(column))) + return Float( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) + + +@convert_sqlalchemy_type.register(types.Enum) +def convert_enum_to_enum(type, column, registry=None): + enum_class = getattr(type, 'enum_class', None) + if enum_class: # Check if an enum.Enum type is used + graphene_type = Enum.from_enum(enum_class) + else: # Nope, just a list of string options + items = zip(type.enums, type.enums) + graphene_type = Enum(type.name, items) + return Field( + graphene_type, + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) @convert_sqlalchemy_type.register(ChoiceType) def convert_column_to_enum(type, column, registry=None): - name = '{}_{}'.format(column.table.name, column.name).upper() + name = "{}_{}".format(column.table.name, column.name).upper() return Enum(name, type.choices, description=get_column_doc(column)) @@ -144,16 +173,24 @@ def convert_postgres_array_to_list(_type, column, registry=None): graphene_type = convert_sqlalchemy_type(column.type.item_type, column) inner_type = type(graphene_type) - return List(inner_type, description=get_column_doc(column), required=not(is_column_nullable(column))) + return List( + inner_type, + description=get_column_doc(column), + required=not (is_column_nullable(column)), + ) @convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSONB) def convert_json_to_string(type, column, registry=None): - return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column))) + return JSONString( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) @convert_sqlalchemy_type.register(JSONType) def convert_json_type_to_string(type, column, registry=None): - return JSONString(description=get_column_doc(column), required=not(is_column_nullable(column))) + return JSONString( + description=get_column_doc(column), required=not (is_column_nullable(column)) + ) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index bb084b3..4a46b74 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,73 +1,134 @@ +import logging from functools import partial +from promise import Promise, is_thenable from sqlalchemy.orm.query import Query -from graphene.relay import ConnectionField +from graphene.relay import Connection, ConnectionField from graphene.relay.connection import PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice -from .utils import get_query +from .utils import get_query, sort_argument_for_model + +log = logging.getLogger() -class SQLAlchemyConnectionField(ConnectionField): +class UnsortedSQLAlchemyConnectionField(ConnectionField): + @property + def type(self): + from .types import SQLAlchemyObjectType + + _type = super(ConnectionField, self).type + if issubclass(_type, Connection): + return _type + assert issubclass(_type, SQLAlchemyObjectType), ( + "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" + ).format(_type.__name__) + assert _type._meta.connection, "The type {} doesn't have a connection".format( + _type.__name__ + ) + return _type._meta.connection @property def model(self): return self.type._meta.node._meta.model @classmethod - def get_query(cls, model, info, **args): - return get_query(model, info.context) - - @property - def type(self): - from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - assert issubclass(_type, SQLAlchemyObjectType), ( - "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types" - ) - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) - return _type._meta.connection + def get_query(cls, model, info, sort=None, **args): + query = get_query(model, info.context) + if sort is not None: + if isinstance(sort, str): + query = query.order_by(sort.value) + else: + query = query.order_by(*(col.value for col in sort)) + return query @classmethod - def connection_resolver(cls, resolver, connection, model, root, info, **args): - iterable = resolver(root, info, **args) - if iterable is None: - iterable = cls.get_query(model, info, **args) - if isinstance(iterable, Query): - _len = iterable.count() + def resolve_connection(cls, connection_type, model, info, args, resolved): + if resolved is None: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() else: - _len = len(iterable) + _len = len(resolved) connection = connection_from_list_slice( - iterable, + resolved, args, slice_start=0, list_length=_len, list_slice_length=_len, - connection_type=connection, + connection_type=connection_type, pageinfo_type=PageInfo, - edge_type=connection.Edge, + edge_type=connection_type.Edge, ) - connection.iterable = iterable + connection.iterable = resolved connection.length = _len return connection + + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + resolved = resolver(root, info, **args) + + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) -__connectionFactory = SQLAlchemyConnectionField +class SQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField): + def __init__(self, type, *args, **kwargs): + if "sort" not in kwargs and issubclass(type, Connection): + # Let super class raise if type is not a Connection + try: + model = type.Edge.node._type._meta.model + kwargs.setdefault("sort", sort_argument_for_model(model)) + except Exception: + raise Exception( + 'Cannot create sort argument for {}. A model is required. Set the "sort" argument' + " to None to disabling the creation of the sort query argument".format( + type.__name__ + ) + ) + elif "sort" in kwargs and kwargs["sort"] is None: + del kwargs["sort"] + super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs) + + +def default_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + model_type = registry.get_type_for_model(model) + return createConnectionField(model_type) + + +# TODO Remove in next major version +__connectionFactory = UnsortedSQLAlchemyConnectionField def createConnectionField(_type): + log.warn( + 'createConnectionField is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) return __connectionFactory(_type) def registerConnectionFieldFactory(factoryMethod): + log.warn( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) global __connectionFactory __connectionFactory = factoryMethod def unregisterConnectionFieldFactory(): + log.warn( + 'registerConnectionFieldFactory is deprecated and will be removed in the next ' + 'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.' + ) global __connectionFactory - __connectionFactory = SQLAlchemyConnectionField + __connectionFactory = UnsortedSQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/registry.py b/graphene_sqlalchemy/registry.py index 61285cb..460053f 100644 --- a/graphene_sqlalchemy/registry.py +++ b/graphene_sqlalchemy/registry.py @@ -1,5 +1,4 @@ class Registry(object): - def __init__(self): self._registry = {} self._registry_models = {} @@ -7,11 +6,12 @@ def register(self, cls): from .types import SQLAlchemyObjectType + assert issubclass(cls, SQLAlchemyObjectType), ( - 'Only classes of type SQLAlchemyObjectType can be registered, ', + "Only classes of type SQLAlchemyObjectType can be registered, " 'received "{}"' ).format(cls.__name__) - assert cls._meta.registry == self, 'Registry for a Model have to match.' + assert cls._meta.registry == self, "Registry for a Model have to match." # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # 'SQLAlchemy model "{}" already associated with ' # 'another type "{}".' diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index 3f27bc4..3ba23a8 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -1,37 +1,50 @@ from __future__ import absolute_import -from sqlalchemy import Column, Date, ForeignKey, Integer, String, Table +import enum + +from sqlalchemy import Column, Date, Enum, ForeignKey, Integer, String, Table from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import mapper, relationship + +class Hairkind(enum.Enum): + LONG = 'long' + SHORT = 'short' + + Base = declarative_base() -association_table = Table('association', Base.metadata, - Column('pet_id', Integer, ForeignKey('pets.id')), - Column('reporter_id', Integer, ForeignKey('reporters.id'))) +association_table = Table( + "association", + Base.metadata, + Column("pet_id", Integer, ForeignKey("pets.id")), + Column("reporter_id", Integer, ForeignKey("reporters.id")), +) class Editor(Base): - __tablename__ = 'editors' + __tablename__ = "editors" editor_id = Column(Integer(), primary_key=True) name = Column(String(100)) class Pet(Base): - __tablename__ = 'pets' + __tablename__ = "pets" id = Column(Integer(), primary_key=True) name = Column(String(30)) - reporter_id = Column(Integer(), ForeignKey('reporters.id')) + pet_kind = Column(Enum("cat", "dog", name="pet_kind"), nullable=False) + hair_kind = Column(Enum(Hairkind, name="hair_kind"), nullable=False) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) class Reporter(Base): - __tablename__ = 'reporters' + __tablename__ = "reporters" id = Column(Integer(), primary_key=True) first_name = Column(String(30)) last_name = Column(String(30)) email = Column(String()) - pets = relationship('Pet', secondary=association_table, backref='reporters') - articles = relationship('Article', backref='reporter') + pets = relationship("Pet", secondary=association_table, backref="reporters") + articles = relationship("Article", backref="reporter") favorite_article = relationship("Article", uselist=False) # total = column_property( @@ -42,19 +55,21 @@ class Article(Base): - __tablename__ = 'articles' + __tablename__ = "articles" id = Column(Integer(), primary_key=True) headline = Column(String(100)) pub_date = Column(Date()) - reporter_id = Column(Integer(), ForeignKey('reporters.id')) + reporter_id = Column(Integer(), ForeignKey("reporters.id")) class ReflectedEditor(type): """Same as Editor, but using reflected table.""" + @classmethod def __subclasses__(cls): return [] -editor_table = Table('editors', Base.metadata, autoload=True) + +editor_table = Table("editors", Base.metadata, autoload=True) mapper(ReflectedEditor, editor_table) diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py deleted file mode 100644 index 867c526..0000000 --- a/graphene_sqlalchemy/tests/test_connectionfactory.py +++ /dev/null @@ -1,28 +0,0 @@ -from graphene_sqlalchemy.fields import SQLAlchemyConnectionField, registerConnectionFieldFactory, unregisterConnectionFieldFactory -import graphene - -def test_register(): - class LXConnectionField(SQLAlchemyConnectionField): - @classmethod - def _applyQueryArgs(cls, model, q, args): - return q - - @classmethod - def connection_resolver(cls, resolver, connection, model, root, args, context, info): - - def LXResolver(root, args, context, info): - iterable = resolver(root, args, context, info) - if iterable is None: - iterable = cls.get_query(model, context, info, args) - - # We accept always a query here. All LX-queries can be filtered and sorted - iterable = cls._applyQueryArgs(model, iterable, args) - return iterable - - return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info) - - def createLXConnectionField(table): - return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) - - registerConnectionFieldFactory(createLXConnectionField) - unregisterConnectionFieldFactory() diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index 3c732b2..5cc16e7 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -1,8 +1,10 @@ +import enum + from py.test import raises -from sqlalchemy import Column, Table, case, types, select, func +from sqlalchemy import Column, Table, case, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import composite, column_property +from sqlalchemy.orm import column_property, composite from sqlalchemy.sql.elements import Label from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType @@ -14,25 +16,32 @@ from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, convert_sqlalchemy_relationship) -from ..fields import SQLAlchemyConnectionField +from ..fields import (UnsortedSQLAlchemyConnectionField, + default_connection_field_factory) from ..registry import Registry from ..types import SQLAlchemyObjectType from .models import Article, Pet, Reporter def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): - column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs) + column = Column(sqlalchemy_type, doc="Custom Help Text", **kwargs) graphene_type = convert_sqlalchemy_column(column) assert isinstance(graphene_type, graphene_field) - field = graphene_type.Field() - assert field.description == 'Custom Help Text' + field = ( + graphene_type + if isinstance(graphene_type, graphene.Field) + else graphene_type.Field() + ) + assert field.description == "Custom Help Text" return field -def assert_composite_conversion(composite_class, composite_columns, graphene_field, - registry, **kwargs): - composite_column = composite(composite_class, *composite_columns, - doc='Custom Help Text', **kwargs) +def assert_composite_conversion( + composite_class, composite_columns, graphene_field, registry, **kwargs +): + composite_column = composite( + composite_class, *composite_columns, doc="Custom Help Text", **kwargs + ) graphene_type = convert_sqlalchemy_composite(composite_column, registry) assert isinstance(graphene_type, graphene_field) field = graphene_type.Field() @@ -45,7 +54,7 @@ def test_should_unknown_sqlalchemy_field_raise_exception(): with raises(Exception) as excinfo: convert_sqlalchemy_column(None) - assert 'Don\'t know how to convert the SQLAlchemy field' in str(excinfo.value) + assert "Don't know how to convert the SQLAlchemy field" in str(excinfo.value) def test_should_date_convert_string(): @@ -76,8 +85,20 @@ assert_column_conversion(types.UnicodeText(), graphene.String) -def test_should_enum_convert_string(): - assert_column_conversion(types.Enum(), graphene.String) +def test_should_enum_convert_enum(): + field = assert_column_conversion( + types.Enum(enum.Enum("one", "two")), graphene.Field + ) + field_type = field.type() + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "two") + field = assert_column_conversion( + types.Enum("one", "two", name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type.__class__.__name__ == "two_numbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "two") def test_should_small_integer_convert_int(): @@ -109,31 +130,29 @@ def test_should_label_convert_string(): - label = Label('label_test', case([], else_="foo"), type_=types.Unicode()) + label = Label("label_test", case([], else_="foo"), type_=types.Unicode()) graphene_type = convert_sqlalchemy_column(label) assert isinstance(graphene_type, graphene.String) def test_should_label_convert_int(): - label = Label('int_label_test', case([], else_="foo"), type_=types.Integer()) + label = Label("int_label_test", case([], else_="foo"), type_=types.Integer()) graphene_type = convert_sqlalchemy_column(label) assert isinstance(graphene_type, graphene.Int) + def test_should_choice_convert_enum(): - TYPES = [ - (u'es', u'Spanish'), - (u'en', u'English') - ] - column = Column(ChoiceType(TYPES), doc='Language', name='language') + TYPES = [(u"es", u"Spanish"), (u"en", u"English")] + column = Column(ChoiceType(TYPES), doc="Language", name="language") Base = declarative_base() - Table('translatedmodel', Base.metadata, column) + Table("translatedmodel", Base.metadata, column) graphene_type = convert_sqlalchemy_column(column) assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.name == 'TRANSLATEDMODEL_LANGUAGE' - assert graphene_type._meta.description == 'Language' - assert graphene_type._meta.enum.__members__['es'].value == 'Spanish' - assert graphene_type._meta.enum.__members__['en'].value == 'English' + assert graphene_type._meta.name == "TRANSLATEDMODEL_LANGUAGE" + assert graphene_type._meta.description == "Language" + assert graphene_type._meta.enum.__members__["es"].value == "Spanish" + assert graphene_type._meta.enum.__members__["en"].value == "English" def test_should_columproperty_convert(): @@ -141,16 +160,14 @@ Base = declarative_base() class Test(Base): - __tablename__ = 'test' + __tablename__ = "test" id = Column(types.Integer, primary_key=True) column = column_property( - select([func.sum(func.cast(id, types.Integer))]).where( - id==1 - ) + select([func.sum(func.cast(id, types.Integer))]).where(id == 1) ) graphene_type = convert_sqlalchemy_column(Test.column) - assert graphene_type.kwargs['required'] == False + assert not graphene_type.kwargs["required"] def test_should_scalar_list_convert_list(): @@ -163,18 +180,21 @@ def test_should_manytomany_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() def test_should_manytomany_convert_connectionorlist_list(): class A(SQLAlchemyObjectType): - class Meta: model = Pet - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A._meta.registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -184,30 +204,34 @@ def test_should_manytomany_convert_connectionorlist_connection(): class A(SQLAlchemyObjectType): - class Meta: model = Pet - interfaces = (Node, ) - - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) - assert isinstance(dynamic_field, graphene.Dynamic) - assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.pets.property, A._meta.registry, default_connection_field_factory + ) + assert isinstance(dynamic_field, graphene.Dynamic) + assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() def test_should_manytoone_convert_connectionorlist_list(): class A(SQLAlchemyObjectType): - class Meta: model = Reporter - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, A._meta.registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -216,12 +240,13 @@ def test_should_manytoone_convert_connectionorlist_connection(): class A(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) - - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Article.reporter.property, A._meta.registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -230,12 +255,13 @@ def test_should_onetoone_convert_field(): class A(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) - - dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry) + interfaces = (Node,) + + dynamic_field = convert_sqlalchemy_relationship( + Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory + ) assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -247,7 +273,23 @@ def test_should_postgresql_enum_convert(): - assert_column_conversion(postgresql.ENUM(), graphene.String) + field = assert_column_conversion( + postgresql.ENUM("one", "two", name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type.__class__.__name__ == "two_numbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "two") + + +def test_should_postgresql_py_enum_convert(): + field = assert_column_conversion( + postgresql.ENUM(enum.Enum("TwoNumbers", "one two"), name="two_numbers"), graphene.Field + ) + field_type = field.type() + assert field_type.__class__.__name__ == "TwoNumbers" + assert isinstance(field_type, graphene.Enum) + assert hasattr(field_type, "two") def test_should_postgresql_array_convert(): @@ -267,9 +309,7 @@ def test_should_composite_convert(): - class CompositeClass(object): - def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 @@ -280,11 +320,12 @@ def convert_composite_class(composite, registry): return graphene.String(description=composite.doc) - assert_composite_conversion(CompositeClass, - (Column(types.Unicode(50)), - Column(types.Unicode(50))), - graphene.String, - registry) + assert_composite_conversion( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + graphene.String, + registry, + ) def test_should_unknown_sqlalchemy_composite_raise_exception(): @@ -293,15 +334,15 @@ with raises(Exception) as excinfo: class CompositeClass(object): - def __init__(self, col1, col2): self.col1 = col1 self.col2 = col2 - assert_composite_conversion(CompositeClass, - (Column(types.Unicode(50)), - Column(types.Unicode(50))), - graphene.String, - registry) - - assert 'Don\'t know how to convert the composite field' in str(excinfo.value) + assert_composite_conversion( + CompositeClass, + (Column(types.Unicode(50)), Column(types.Unicode(50))), + graphene.String, + registry, + ) + + assert "Don't know how to convert the composite field" in str(excinfo.value) diff --git a/graphene_sqlalchemy/tests/test_fields.py b/graphene_sqlalchemy/tests/test_fields.py new file mode 100644 index 0000000..ff616b3 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_fields.py @@ -0,0 +1,40 @@ +import pytest + +from graphene.relay import Connection + +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType +from ..utils import sort_argument_for_model +from .models import Editor +from .models import Pet as PetModel + + +class Pet(SQLAlchemyObjectType): + class Meta: + model = PetModel + + +class PetConn(Connection): + class Meta: + node = Pet + + +def test_sort_added_by_default(): + arg = SQLAlchemyConnectionField(PetConn) + assert "sort" in arg.args + assert arg.args["sort"] == sort_argument_for_model(PetModel) + + +def test_sort_can_be_removed(): + arg = SQLAlchemyConnectionField(PetConn, sort=None) + assert "sort" not in arg.args + + +def test_custom_sort(): + arg = SQLAlchemyConnectionField(PetConn, sort=sort_argument_for_model(Editor)) + assert arg.args["sort"] == sort_argument_for_model(Editor) + + +def test_init_raises(): + with pytest.raises(Exception, match="Cannot create sort"): + SQLAlchemyConnectionField(Connection) diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index e4c3f83..146c54e 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -3,17 +3,18 @@ from sqlalchemy.orm import scoped_session, sessionmaker import graphene -from graphene.relay import Node - +from graphene.relay import Connection, Node + +from ..fields import SQLAlchemyConnectionField from ..registry import reset_global_registry -from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from .models import Article, Base, Editor, Reporter - -db = create_engine('sqlite:///test_sqlalchemy.sqlite3') - - -@pytest.yield_fixture(scope='function') +from ..utils import sort_argument_for_model, sort_enum_for_model +from .models import Article, Base, Editor, Hairkind, Pet, Reporter + +db = create_engine("sqlite:///test_sqlalchemy.sqlite3") + + +@pytest.yield_fixture(scope="function") def session(): reset_global_registry() connection = db.engine.connect() @@ -33,11 +34,13 @@ def setup_fixtures(session): - reporter = Reporter(first_name='ABA', last_name='X') + pet = Pet(name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG) + session.add(pet) + reporter = Reporter(first_name="ABA", last_name="X") session.add(reporter) - reporter2 = Reporter(first_name='ABO', last_name='Y') + reporter2 = Reporter(first_name="ABO", last_name="Y") session.add(reporter2) - article = Article(headline='Hi!') + article = Article(headline="Hi!") article.reporter = reporter session.add(article) editor = Editor(name="John") @@ -49,7 +52,6 @@ setup_fixtures(session) class ReporterType(SQLAlchemyObjectType): - class Meta: model = Reporter @@ -63,7 +65,7 @@ def resolve_reporters(self, *args, **kwargs): return session.query(Reporter) - query = ''' + query = """ query ReporterQuery { reporter { firstName, @@ -74,18 +76,10 @@ firstName } } - ''' + """ expected = { - 'reporter': { - 'firstName': 'ABA', - 'lastName': 'X', - 'email': None - }, - 'reporters': [{ - 'firstName': 'ABA', - }, { - 'firstName': 'ABO', - }] + "reporter": {"firstName": "ABA", "lastName": "X", "email": None}, + "reporters": [{"firstName": "ABA"}, {"firstName": "ABO"}], } schema = graphene.Schema(query=Query) result = schema.execute(query) @@ -93,34 +87,136 @@ assert result.data == expected +def test_should_query_enums(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType) + + def resolve_pet(self, *args, **kwargs): + return session.query(Pet).first() + + query = """ + query PetQuery { + pet { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query) + assert not result.errors + assert result.data == expected, result.data + + +def test_enum_parameter(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['pet_kind'].type.of_type)) + + def resolve_pet(self, info, kind=None, *args, **kwargs): + query = session.query(Pet) + if kind: + query = query.filter(Pet.pet_kind == kind) + return query.first() + + query = """ + query PetQuery($kind: pet_kind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "cat"}) + assert not result.errors + assert result.data == {"pet": None} + result = schema.execute(query, variables={"kind": "dog"}) + assert not result.errors + assert result.data == expected, result.data + + +def test_py_enum_parameter(session): + setup_fixtures(session) + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + + class Query(graphene.ObjectType): + pet = graphene.Field(PetType, kind=graphene.Argument(PetType._meta.fields['hair_kind'].type.of_type)) + + def resolve_pet(self, info, kind=None, *args, **kwargs): + query = session.query(Pet) + if kind: + # XXX Why kind passed in as a str instead of a Hairkind instance? + query = query.filter(Pet.hair_kind == Hairkind(kind)) + return query.first() + + query = """ + query PetQuery($kind: Hairkind) { + pet(kind: $kind) { + name, + petKind + hairKind + } + } + """ + expected = {"pet": {"name": "Lassie", "petKind": "dog", "hairKind": "LONG"}} + schema = graphene.Schema(query=Query) + result = schema.execute(query, variables={"kind": "SHORT"}) + assert not result.errors + assert result.data == {"pet": None} + result = schema.execute(query, variables={"kind": "LONG"}) + assert not result.errors + assert result.data == expected, result.data + + def test_should_node(session): setup_fixtures(session) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') + def get_node(cls, info, id): + return Reporter(id=2, first_name="Cookie Monster") class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) # @classmethod # def get_node(cls, id, info): # return Article(id=1, headline='Article node') + class ArticleConnection(Connection): + class Meta: + node = ArticleNode + class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleConnection) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -128,7 +224,7 @@ def resolve_article(self, *args, **kwargs): return session.query(Article).first() - query = ''' + query = """ query ReporterQuery { reporter { id, @@ -160,35 +256,20 @@ } } } - ''' + """ expected = { - 'reporter': { - 'id': 'UmVwb3J0ZXJOb2RlOjE=', - 'firstName': 'ABA', - 'lastName': 'X', - 'email': None, - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Hi!' - } - }] - }, + "reporter": { + "id": "UmVwb3J0ZXJOb2RlOjE=", + "firstName": "ABA", + "lastName": "X", + "email": None, + "articles": {"edges": [{"node": {"headline": "Hi!"}}]}, }, - 'allArticles': { - 'edges': [{ - 'node': { - 'headline': 'Hi!' - } - }] - }, - 'myArticle': { - 'id': 'QXJ0aWNsZU5vZGU6MQ==', - 'headline': 'Hi!' - } + "allArticles": {"edges": [{"node": {"headline": "Hi!"}}]}, + "myArticle": {"id": "QXJ0aWNsZU5vZGU6MQ==", "headline": "Hi!"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors assert result.data == expected @@ -197,16 +278,19 @@ setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) + + class EditorConnection(Connection): + class Meta: + node = EditorNode class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorNode) - - query = ''' + all_editors = SQLAlchemyConnectionField(EditorConnection) + + query = """ query EditorQuery { allEditors { edges { @@ -222,23 +306,14 @@ } } } - ''' + """ expected = { - 'allEditors': { - 'edges': [{ - 'node': { - 'id': 'RWRpdG9yTm9kZTox', - 'name': 'John' - } - }] - }, - 'node': { - 'name': 'John' - } + "allEditors": {"edges": [{"node": {"id": "RWRpdG9yTm9kZTox", "name": "John"}}]}, + "node": {"name": "John"}, } schema = graphene.Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors assert result.data == expected @@ -247,29 +322,25 @@ setup_fixtures(session) class EditorNode(SQLAlchemyObjectType): - class Meta: model = Editor - interfaces = (Node, ) + interfaces = (Node,) class ReporterNode(SQLAlchemyObjectType): - class Meta: model = Reporter - interfaces = (Node, ) + interfaces = (Node,) @classmethod def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') + return Reporter(id=2, first_name="Cookie Monster") class ArticleNode(SQLAlchemyObjectType): - class Meta: model = Article - interfaces = (Node, ) + interfaces = (Node,) class CreateArticle(graphene.Mutation): - class Arguments: headline = graphene.String() reporter_id = graphene.ID() @@ -278,10 +349,7 @@ article = graphene.Field(ArticleNode) def mutate(self, info, headline, reporter_id): - new_article = Article( - headline=headline, - reporter_id=reporter_id, - ) + new_article = Article(headline=headline, reporter_id=reporter_id) session.add(new_article) session.commit() @@ -295,7 +363,7 @@ class Mutation(graphene.ObjectType): create_article = CreateArticle.Field() - query = ''' + query = """ mutation ArticleCreator { createArticle( headline: "My Article" @@ -311,21 +379,179 @@ } } } - ''' + """ expected = { - 'createArticle': { - 'ok': True, - 'article': { - 'headline': 'My Article', - 'reporter': { - 'id': 'UmVwb3J0ZXJOb2RlOjE=', - 'firstName': 'ABA' - } - } - }, + "createArticle": { + "ok": True, + "article": { + "headline": "My Article", + "reporter": {"id": "UmVwb3J0ZXJOb2RlOjE=", "firstName": "ABA"}, + }, + } } schema = graphene.Schema(query=Query, mutation=Mutation) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors assert result.data == expected + + +def sort_setup(session): + pets = [ + Pet(id=2, name="Lassie", pet_kind="dog", hair_kind=Hairkind.LONG), + Pet(id=22, name="Alf", pet_kind="cat", hair_kind=Hairkind.LONG), + Pet(id=3, name="Barf", pet_kind="dog", hair_kind=Hairkind.LONG), + ] + session.add_all(pets) + session.commit() + + +def test_sort(session): + sort_setup(session) + + class PetNode(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node,) + + class PetConnection(Connection): + class Meta: + node = PetNode + + class Query(graphene.ObjectType): + defaultSort = SQLAlchemyConnectionField(PetConnection) + nameSort = SQLAlchemyConnectionField(PetConnection) + multipleSort = SQLAlchemyConnectionField(PetConnection) + descSort = SQLAlchemyConnectionField(PetConnection) + singleColumnSort = SQLAlchemyConnectionField( + PetConnection, sort=graphene.Argument(sort_enum_for_model(Pet)) + ) + noDefaultSort = SQLAlchemyConnectionField( + PetConnection, sort=sort_argument_for_model(Pet, False) + ) + noSort = SQLAlchemyConnectionField(PetConnection, sort=None) + + query = """ + query sortTest { + defaultSort{ + edges{ + node{ + id + } + } + } + nameSort(sort: name_asc){ + edges{ + node{ + name + } + } + } + multipleSort(sort: [pet_kind_asc, name_desc]){ + edges{ + node{ + name + petKind + } + } + } + descSort(sort: [name_desc]){ + edges{ + node{ + name + } + } + } + singleColumnSort(sort: name_desc){ + edges{ + node{ + name + } + } + } + noDefaultSort(sort: name_asc){ + edges{ + node{ + name + } + } + } + } + """ + + def makeNodes(nodeList): + nodes = [{"node": item} for item in nodeList] + return {"edges": nodes} + + expected = { + "defaultSort": makeNodes( + [{"id": "UGV0Tm9kZToy"}, {"id": "UGV0Tm9kZToz"}, {"id": "UGV0Tm9kZToyMg=="}] + ), + "nameSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), + "noDefaultSort": makeNodes( + [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] + ), + "multipleSort": makeNodes( + [ + {"name": "Alf", "petKind": "cat"}, + {"name": "Lassie", "petKind": "dog"}, + {"name": "Barf", "petKind": "dog"}, + ] + ), + "descSort": makeNodes([{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}]), + "singleColumnSort": makeNodes( + [{"name": "Lassie"}, {"name": "Barf"}, {"name": "Alf"}] + ), + } # yapf: disable + + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={"session": session}) + assert not result.errors + assert result.data == expected + + queryError = """ + query sortTest { + singleColumnSort(sort: [pet_kind_asc, name_desc]){ + edges{ + node{ + name + } + } + } + } + """ + result = schema.execute(queryError, context_value={"session": session}) + assert result.errors is not None + + queryNoSort = """ + query sortTest { + noDefaultSort{ + edges{ + node{ + name + } + } + } + noSort{ + edges{ + node{ + name + } + } + } + } + """ + + expectedNoSort = { + "noDefaultSort": makeNodes( + [{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}] + ), + "noSort": makeNodes([{"name": "Alf"}, {"name": "Barf"}, {"name": "Lassie"}]), + } # yapf: disable + + result = schema.execute(queryNoSort, context_value={"session": session}) + assert not result.errors + for key, value in result.data.items(): + assert set(node["node"]["name"] for node in value["edges"]) == set( + node["node"]["name"] for node in expectedNoSort[key]["edges"] + ) diff --git a/graphene_sqlalchemy/tests/test_reflected.py b/graphene_sqlalchemy/tests/test_reflected.py index 2ea3d26..46e10de 100644 --- a/graphene_sqlalchemy/tests/test_reflected.py +++ b/graphene_sqlalchemy/tests/test_reflected.py @@ -9,7 +9,6 @@ class Reflected(SQLAlchemyObjectType): - class Meta: model = ReflectedEditor registry = registry @@ -18,7 +17,4 @@ def test_objecttype_registered(): assert issubclass(Reflected, ObjectType) assert Reflected._meta.model == ReflectedEditor - assert list( - Reflected._meta.fields.keys()) == ['editor_id', 'name'] - - + assert list(Reflected._meta.fields.keys()) == ["editor_id", "name"] diff --git a/graphene_sqlalchemy/tests/test_registry.py b/graphene_sqlalchemy/tests/test_registry.py new file mode 100644 index 0000000..1945af6 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_registry.py @@ -0,0 +1,33 @@ +import pytest + +from ..registry import Registry +from ..types import SQLAlchemyObjectType +from .models import Pet + + +def test_register_incorrect_objecttype(): + reg = Registry() + + class Spam: + pass + + with pytest.raises(AssertionError) as excinfo: + reg.register(Spam) + + assert "Only classes of type SQLAlchemyObjectType can be registered" in str( + excinfo.value + ) + + +def test_register_objecttype(): + reg = Registry() + + class PetType(SQLAlchemyObjectType): + class Meta: + model = Pet + registry = reg + + try: + reg.register(PetType) + except AssertionError: + pytest.fail("expected no AssertionError") diff --git a/graphene_sqlalchemy/tests/test_schema.py b/graphene_sqlalchemy/tests/test_schema.py index 058c915..628da18 100644 --- a/graphene_sqlalchemy/tests/test_schema.py +++ b/graphene_sqlalchemy/tests/test_schema.py @@ -7,42 +7,43 @@ def test_should_raise_if_no_model(): with raises(Exception) as excinfo: + class Character1(SQLAlchemyObjectType): pass - assert 'valid SQLAlchemy Model' in str(excinfo.value) + + assert "valid SQLAlchemy Model" in str(excinfo.value) def test_should_raise_if_model_is_invalid(): with raises(Exception) as excinfo: + class Character2(SQLAlchemyObjectType): - class Meta: model = 1 - assert 'valid SQLAlchemy Model' in str(excinfo.value) + + assert "valid SQLAlchemy Model" in str(excinfo.value) def test_should_map_fields_correctly(): class ReporterType2(SQLAlchemyObjectType): - class Meta: model = Reporter registry = Registry() - assert list( - ReporterType2._meta.fields.keys()) == [ - 'id', - 'first_name', - 'last_name', - 'email', - 'pets', - 'articles', - 'favorite_article'] + assert list(ReporterType2._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "pets", + "articles", + "favorite_article", + ] def test_should_map_only_few_fields(): class Reporter2(SQLAlchemyObjectType): - class Meta: model = Reporter - only_fields = ('id', 'email') - assert list(Reporter2._meta.fields.keys()) == ['id', 'email'] + only_fields = ("id", "email") + assert list(Reporter2._meta.fields.keys()) == ["id", "email"] diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 3f017aa..0360a64 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,32 +1,40 @@ - -from graphene import Field, Int, Interface, ObjectType -from graphene.relay import Node, is_node -import six - +from collections import OrderedDict + +import six # noqa F401 +from promise import Promise + +from graphene import (Connection, Field, Int, Interface, Node, ObjectType, + is_node) + +from ..fields import (SQLAlchemyConnectionField, + UnsortedSQLAlchemyConnectionField, + registerConnectionFieldFactory, + unregisterConnectionFieldFactory) from ..registry import Registry -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, Reporter registry = Registry() class Character(SQLAlchemyObjectType): - '''Character description''' + """Character description""" + class Meta: model = Reporter registry = registry class Human(SQLAlchemyObjectType): - '''Human description''' + """Human description""" pub_date = Int() class Meta: model = Article - exclude_fields = ('id', ) + exclude_fields = ("id",) registry = registry - interfaces = (Node, ) + interfaces = (Node,) def test_sqlalchemy_interface(): @@ -44,15 +52,15 @@ def test_objecttype_registered(): assert issubclass(Character, ObjectType) assert Character._meta.model == Reporter - assert list( - Character._meta.fields.keys()) == [ - 'id', - 'first_name', - 'last_name', - 'email', - 'pets', - 'articles', - 'favorite_article'] + assert list(Character._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "pets", + "articles", + "favorite_article", + ] # def test_sqlalchemynode_idfield(): @@ -66,16 +74,14 @@ def test_node_replacedfield(): - idfield = Human._meta.fields['pub_date'] + idfield = Human._meta.fields["pub_date"] assert isinstance(idfield, Field) assert idfield.type == Int def test_object_type(): - - class Human(SQLAlchemyObjectType): - '''Human description''' + """Human description""" pub_date = Int() @@ -83,12 +89,17 @@ model = Article # exclude_fields = ('id', ) registry = registry - interfaces = (Node, ) + interfaces = (Node,) assert issubclass(Human, ObjectType) - assert list(Human._meta.fields.keys()) == ['id', 'headline', 'pub_date', 'reporter_id', 'reporter'] + assert list(Human._meta.fields.keys()) == [ + "id", + "headline", + "pub_date", + "reporter_id", + "reporter", + ] assert is_node(Human) - # Test Custom SQLAlchemyObjectType Implementation @@ -98,7 +109,8 @@ class CustomCharacter(CustomSQLAlchemyObjectType): - '''Character description''' + """Character description""" + class Meta: model = Reporter registry = registry @@ -107,12 +119,161 @@ def test_custom_objecttype_registered(): assert issubclass(CustomCharacter, ObjectType) assert CustomCharacter._meta.model == Reporter - assert list( - CustomCharacter._meta.fields.keys()) == [ - 'id', - 'first_name', - 'last_name', - 'email', - 'pets', - 'articles', - 'favorite_article'] + assert list(CustomCharacter._meta.fields.keys()) == [ + "id", + "first_name", + "last_name", + "email", + "pets", + "articles", + "favorite_article", + ] + + +# Test Custom SQLAlchemyObjectType with Custom Options +class CustomOptions(SQLAlchemyObjectTypeOptions): + custom_option = None + custom_fields = None + + +class SQLAlchemyObjectTypeWithCustomOptions(SQLAlchemyObjectType): + class Meta: + abstract = True + + @classmethod + def __init_subclass_with_meta__( + cls, custom_option=None, custom_fields=None, **options + ): + _meta = CustomOptions(cls) + _meta.custom_option = custom_option + _meta.fields = custom_fields + super(SQLAlchemyObjectTypeWithCustomOptions, cls).__init_subclass_with_meta__( + _meta=_meta, **options + ) + + +class ReporterWithCustomOptions(SQLAlchemyObjectTypeWithCustomOptions): + class Meta: + model = Reporter + custom_option = "custom_option" + custom_fields = OrderedDict([("custom_field", Field(Int()))]) + + +def test_objecttype_with_custom_options(): + assert issubclass(ReporterWithCustomOptions, ObjectType) + assert ReporterWithCustomOptions._meta.model == Reporter + assert list(ReporterWithCustomOptions._meta.fields.keys()) == [ + "custom_field", + "id", + "first_name", + "last_name", + "email", + "pets", + "articles", + "favorite_article", + ] + assert ReporterWithCustomOptions._meta.custom_option == "custom_option" + assert isinstance(ReporterWithCustomOptions._meta.fields["custom_field"].type, Int) + + +def test_promise_connection_resolver(): + class TestConnection(Connection): + class Meta: + node = ReporterWithCustomOptions + + def resolver(*args, **kwargs): + return Promise.resolve([]) + + result = SQLAlchemyConnectionField.connection_resolver( + resolver, TestConnection, ReporterWithCustomOptions, None, None + ) + assert result is not None + + +# Tests for connection_field_factory + +class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField): + pass + + +def test_default_connection_field_factory(): + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField) + + +def test_register_connection_field_factory(): + def test_connection_field_factory(relationship, registry): + model = relationship.mapper.entity + _type = registry.get_type_for_model(model) + return _TestSQLAlchemyConnectionField(_type._meta.connection) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + connection_field_factory = test_connection_field_factory + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_registerConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) + + +def test_deprecated_unregisterConnectionFieldFactory(): + registerConnectionFieldFactory(_TestSQLAlchemyConnectionField) + unregisterConnectionFieldFactory() + + _registry = Registry() + + class ReporterType(SQLAlchemyObjectType): + class Meta: + model = Reporter + registry = _registry + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + registry = _registry + interfaces = (Node,) + + assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField) diff --git a/graphene_sqlalchemy/tests/test_utils.py b/graphene_sqlalchemy/tests/test_utils.py index 8af3c61..a7b902f 100644 --- a/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene_sqlalchemy/tests/test_utils.py @@ -1,10 +1,13 @@ -from graphene import ObjectType, Schema, String +import sqlalchemy as sa -from ..utils import get_session +from graphene import Enum, List, ObjectType, Schema, String + +from ..utils import get_session, sort_argument_for_model, sort_enum_for_model +from .models import Editor, Pet def test_get_session(): - session = 'My SQLAlchemy session' + session = "My SQLAlchemy session" class Query(ObjectType): x = String() @@ -12,13 +15,62 @@ def resolve_x(self, info): return get_session(info.context) - query = ''' + query = """ query ReporterQuery { x } - ''' + """ schema = Schema(query=Query) - result = schema.execute(query, context_value={'session': session}) + result = schema.execute(query, context_value={"session": session}) assert not result.errors - assert result.data['x'] == session + assert result.data["x"] == session + + +def test_sort_enum_for_model(): + enum = sort_enum_for_model(Pet) + assert isinstance(enum, type(Enum)) + assert str(enum) == "PetSortEnum" + for col in sa.inspect(Pet).columns: + assert hasattr(enum, col.name + "_asc") + assert hasattr(enum, col.name + "_desc") + + +def test_sort_enum_for_model_custom_naming(): + enum = sort_enum_for_model(Pet, "Foo", lambda n, d: n.upper() + ("A" if d else "D")) + assert str(enum) == "Foo" + for col in sa.inspect(Pet).columns: + assert hasattr(enum, col.name.upper() + "A") + assert hasattr(enum, col.name.upper() + "D") + + +def test_enum_cache(): + assert sort_enum_for_model(Editor) is sort_enum_for_model(Editor) + + +def test_sort_argument_for_model(): + arg = sort_argument_for_model(Pet) + + assert isinstance(arg.type, List) + assert arg.default_value == [Pet.id.name + "_asc"] + assert arg.type.of_type == sort_enum_for_model(Pet) + + +def test_sort_argument_for_model_no_default(): + arg = sort_argument_for_model(Pet, False) + + assert arg.default_value is None + + +def test_sort_argument_for_model_multiple_pk(): + Base = sa.ext.declarative.declarative_base() + + class MultiplePK(Base): + foo = sa.Column(sa.Integer, primary_key=True) + bar = sa.Column(sa.Integer, primary_key=True) + __tablename__ = "MultiplePK" + + arg = sort_argument_for_model(MultiplePK) + assert set(arg.default_value) == set( + (MultiplePK.foo.name + "_asc", MultiplePK.bar.name + "_asc") + ) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a..394d506 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,7 +1,8 @@ from collections import OrderedDict +import sqlalchemy +from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect as sqlalchemyinspect -from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.exc import NoResultFound from graphene import Field # , annotate, ResolveInfo @@ -11,13 +12,14 @@ from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, - convert_sqlalchemy_relationship, - convert_sqlalchemy_hybrid_method) + convert_sqlalchemy_hybrid_method, + convert_sqlalchemy_relationship) +from .fields import default_connection_field_factory from .registry import Registry, get_global_registry from .utils import get_query, is_mapped_class, is_mapped_instance -def construct_fields(model, registry, only_fields, exclude_fields): +def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory): inspected_model = sqlalchemyinspect(model) fields = OrderedDict() @@ -58,9 +60,7 @@ # in there. Or when we exclude this field in exclude_fields continue - converted_hybrid_property = convert_sqlalchemy_hybrid_method( - hybrid_item - ) + converted_hybrid_property = convert_sqlalchemy_hybrid_method(hybrid_item) fields[name] = converted_hybrid_property # Get all the columns for the relationships on the model @@ -72,7 +72,7 @@ # We skip this field if we specify only_fields and is not # in there. Or when we exclude this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship, registry) + converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory) name = relationship.key fields[name] = converted_relationship @@ -80,55 +80,89 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): - model = None # type: Model - registry = None # type: Registry - connection = None # type: Type[Connection] + model = None # type: sqlalchemy.Model + registry = None # type: sqlalchemy.Registry + connection = None # type: sqlalchemy.Type[sqlalchemy.Connection] id = None # type: str class SQLAlchemyObjectType(ObjectType): @classmethod - def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, - only_fields=(), exclude_fields=(), connection=None, - use_connection=None, interfaces=(), id=None, **options): + def __init_subclass_with_meta__( + cls, + model=None, + registry=None, + skip_registry=False, + only_fields=(), + exclude_fields=(), + connection=None, + connection_class=None, + use_connection=None, + interfaces=(), + id=None, + connection_field_factory=default_connection_field_factory, + _meta=None, + **options + ): assert is_mapped_class(model), ( - 'You need to pass a valid SQLAlchemy Model in ' - '{}.Meta, received "{}".' + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' ).format(cls.__name__, model) if not registry: registry = get_global_registry() assert isinstance(registry, Registry), ( - 'The attribute registry in {} needs to be an instance of ' + "The attribute registry in {} needs to be an instance of " 'Registry, received "{}".' ).format(cls.__name__, registry) sqla_fields = yank_fields_from_attrs( - construct_fields(model, registry, only_fields, exclude_fields), - _as=Field, + construct_fields( + model=model, + registry=registry, + only_fields=only_fields, + exclude_fields=exclude_fields, + connection_field_factory=connection_field_factory + ), + _as=Field ) if use_connection is None and interfaces: - use_connection = any((issubclass(interface, Node) for interface in interfaces)) + use_connection = any( + (issubclass(interface, Node) for interface in interfaces) + ) if use_connection and not connection: # We create the connection automatically - connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) + if not connection_class: + connection_class = Connection + + connection = connection_class.create_type( + "{}Connection".format(cls.__name__), node=cls + ) if connection is not None: assert issubclass(connection, Connection), ( "The connection must be a Connection. Received {}" ).format(connection.__name__) - _meta = SQLAlchemyObjectTypeOptions(cls) + if not _meta: + _meta = SQLAlchemyObjectTypeOptions(cls) + _meta.model = model _meta.registry = registry - _meta.fields = sqla_fields + + if _meta.fields: + _meta.fields.update(sqla_fields) + else: + _meta.fields = sqla_fields + _meta.connection = connection - _meta.id = id or 'id' + _meta.id = id or "id" - super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options) + super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__( + _meta=_meta, interfaces=interfaces, **options + ) if not skip_registry: registry.register(cls) @@ -138,9 +172,7 @@ if isinstance(root, cls): return True if not is_mapped_instance(root): - raise Exception(( - 'Received incompatible instance "{}".' - ).format(root)) + raise Exception(('Received incompatible instance "{}".').format(root)) return isinstance(root, cls._meta.model) @classmethod diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index e78c980..276a807 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -1,19 +1,24 @@ from sqlalchemy.exc import ArgumentError +from sqlalchemy.inspection import inspect from sqlalchemy.orm import class_mapper, object_mapper from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError +from graphene import Argument, Enum, List + def get_session(context): - return context.get('session') + return context.get("session") def get_query(model, context): - query = getattr(model, 'query', None) + query = getattr(model, "query", None) if not query: session = get_session(context) if not session: - raise Exception('A query in the model Base or a session in the schema is required for querying.\n' - 'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying') + raise Exception( + "A query in the model Base or a session in the schema is required for querying.\n" + "Read more http://docs.graphene-python.org/projects/sqlalchemy/en/latest/tips/#querying" + ) query = session.query(model) return query @@ -34,3 +39,72 @@ return False else: return True + + +def _symbol_name(column_name, is_asc): + return column_name + ("_asc" if is_asc else "_desc") + + +class EnumValue(str): + """Subclass of str that stores a string and an arbitrary value in the "value" property""" + + def __new__(cls, str_value, value): + return super(EnumValue, cls).__new__(cls, str_value) + + def __init__(self, str_value, value): + super(EnumValue, self).__init__() + self.value = value + + +# Cache for the generated enums, to avoid name clash +_ENUM_CACHE = {} + + +def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): + name = name or cls.__name__ + "SortEnum" + if name in _ENUM_CACHE: + return _ENUM_CACHE[name] + items = [] + default = [] + for column in inspect(cls).columns.values(): + asc_name = symbol_name(column.name, True) + asc_value = EnumValue(asc_name, column.asc()) + desc_name = symbol_name(column.name, False) + desc_value = EnumValue(desc_name, column.desc()) + if column.primary_key: + default.append(asc_value) + items.extend(((asc_name, asc_value), (desc_name, desc_value))) + enum = Enum(name, items) + _ENUM_CACHE[name] = (enum, default) + return enum, default + + +def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name): + """Create Graphene Enum for sorting a SQLAlchemy class query + + Parameters + - cls : Sqlalchemy model class + Model used to create the sort enumerator + - name : str, optional, default None + Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'` + - symbol_name : function, optional, default `_symbol_name` + Function which takes the column name and a boolean indicating if the sort direction is ascending, + and returns the symbol name for the current column and sort direction. + The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc' + + Returns + - Enum + The Graphene enumerator + """ + enum, _ = _sort_enum_for_model(cls, name, symbol_name) + return enum + + +def sort_argument_for_model(cls, has_default=True): + """Returns a Graphene argument for the sort field that accepts a list of sorting directions for a model. + If `has_default` is True (the default) it will sort the result by the primary key(s) + """ + enum, default = _sort_enum_for_model(cls) + if not has_default: + default = None + return Argument(List(enum), default_value=default) diff --git a/setup.cfg b/setup.cfg index d8d54e3..7fd23df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,29 +1,16 @@ +[aliases] +test=pytest + [flake8] exclude = setup.py,docs/*,examples/*,tests max-line-length = 120 -[coverage:run] -omit = */tests/* - [isort] -known_first_party=graphene,graphene_sqlalchemy - -[tool:pytest] -testpaths = graphene_sqlalchemy/ -addopts = - -s - ; --cov graphene-sqlalchemy -norecursedirs = - __pycache__ - *.egg-info - .cache - .git - .tox - appdir - docs -filterwarnings = - error - ignore::DeprecationWarning +known_graphene=graphene,graphql_relay,flask_graphql,graphql_server,sphinx_graphene_theme +known_first_party=graphene_sqlalchemy +known_third_party=flask,nameko,promise,py,pytest,setuptools,singledispatch,six,sqlalchemy,sqlalchemy_utils +sections=FUTURE,STDLIB,THIRDPARTY,GRAPHENE,FIRSTPARTY,LOCALFOLDER +no_lines_before=FIRSTPARTY [bdist_wheel] universal=1 diff --git a/setup.py b/setup.py index f2e7baf..66704b2 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,69 @@ -from setuptools import find_packages, setup -import sys import ast import re +import sys -_version_re = re.compile(r'__version__\s+=\s+(.*)') +from setuptools import find_packages, setup -with open('graphene_sqlalchemy/__init__.py', 'rb') as f: - version = str(ast.literal_eval(_version_re.search( - f.read().decode('utf-8')).group(1))) +_version_re = re.compile(r"__version__\s+=\s+(.*)") +with open("graphene_sqlalchemy/__init__.py", "rb") as f: + version = str( + ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)) + ) + +requirements = [ + # To keep things simple, we only support newer versions of Graphene + "graphene>=2.1.3,<3", + # Tests fail with 1.0.19 + "SQLAlchemy>=1.1,<2", + "six>=1.10.0,<2", + "singledispatch>=3.4.0.3,<4", +] +try: + import enum +except ImportError: # Python < 2.7 and Python 3.3 + requirements.append("enum34 >= 1.1.6") + +tests_require = [ + "pytest==4.3.1", + "mock==2.0.0", + "pytest-cov==2.6.1", + "sqlalchemy_utils==0.33.9", +] setup( - name='graphene-sqlalchemy', + name="graphene-sqlalchemy", version=version, - - description='Graphene SQLAlchemy integration', - long_description=open('README.rst').read(), - - url='https://github.com/graphql-python/graphene-sqlalchemy', - - author='Syrus Akbary', - author_email='me@syrusakbary.com', - - license='MIT', - + description="Graphene SQLAlchemy integration", + long_description=open("README.rst").read(), + url="https://github.com/graphql-python/graphene-sqlalchemy", + author="Syrus Akbary", + author_email="me@syrusakbary.com", + license="MIT", classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: Implementation :: PyPy', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: Implementation :: PyPy", ], - - keywords='api graphql protocol rest relay graphene', - - packages=find_packages(exclude=['tests']), - - install_requires=[ - 'six>=1.10.0', - 'graphene>=2.0', - 'SQLAlchemy', - 'singledispatch>=3.4.0.3', - 'iso8601', - ], - tests_require=[ - 'pytest>=2.7.2', - 'mock', - 'sqlalchemy_utils', - ], + keywords="api graphql protocol rest relay graphene", + packages=find_packages(exclude=["tests"]), + install_requires=requirements, + extras_require={ + "dev": [ + "tox==3.7.0", # Should be kept in sync with tox.ini + "coveralls==1.7.0", + "pre-commit==1.14.4", + ], + "test": tests_require, + }, + tests_require=tests_require, ) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..e55f7d9 --- /dev/null +++ b/tox.ini @@ -0,0 +1,20 @@ +[tox] +envlist = pre-commit,py{27,34,35,36,37}-sql{11,12,13} +skipsdist = true +minversion = 3.7.0 + +[testenv] +deps = + .[test] + sql11: sqlalchemy>=1.1,<1.2 + sql12: sqlalchemy>=1.2,<1.3 + sql13: sqlalchemy>=1.3,<1.4 +commands = + pytest graphene_sqlalchemy --cov=graphene_sqlalchemy {posargs} + +[testenv:pre-commit] +basepython=python3.7 +deps = + .[dev] +commands = + pre-commit {posargs:run --all-files}