diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 310c851..88d565e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.7.3 + rev: v2.11.0 hooks: - id: pyupgrade args: ["--py36-plus"] @@ -9,19 +9,20 @@ hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.0 hooks: - id: flake8 - additional_dependencies: [flake8-bugbear==20.1.0] + additional_dependencies: [flake8-bugbear==21.4.3] - repo: https://github.com/asottile/blacken-docs - rev: v1.8.0 + rev: v1.10.0 hooks: - id: blacken-docs additional_dependencies: [black==20.8b1] args: ["--target-version", "py35"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.790 + rev: v0.812 hooks: - id: mypy language_version: python3 files: ^src/webargs/ + additional_dependencies: ["marshmallow>=3,<4"] diff --git a/AUTHORS.rst b/AUTHORS.rst index 2e15b92..3a904d0 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -51,3 +51,4 @@ * Lefteris Karapetsas `@lefterisjp `_ * Utku Gultopu `@ugultopu `_ * Jason Williams `@jaswilli `_ +* Grey Li `@greyli `_ diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4549beb..c51bf8f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,5 +1,31 @@ Changelog --------- + +8.0.0 (2021-04-08) +****************** + +Features: + +* Add `Parser.pre_load` as a method for allowing users to modify data before + schema loading, but without redefining location loaders. See advanced docs on + `Parser pre_load` for usage information + +* ``unknown`` defaults to `None` for body locations (`json`, `form` and + `json_or_form`) (:issue:`580`). + +* Detection of fields as "multi-value" for unpacking lists from multi-dict + types is now extensible with the ``is_multiple`` attribute. If a field sets + ``is_multiple = True`` it will be detected as a multi-value field. + (:issue:`563`) + +* If ``is_multiple`` is not set or is set to ``None``, webargs will check if the + field is an instance of ``List`` or ``Tuple``. + +* A new attribute on ``Parser`` objects, ``Parser.KNOWN_MULTI_FIELDS`` can be + used to set fields which should be detected as ``is_multiple=True`` even when + the attribute is not set. + +See docs on "Multi-Field Detection" for more details. 7.0.1 (2020-12-14) ****************** diff --git a/debian/changelog b/debian/changelog index f89c49b..733d732 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +python-webargs (8.0.0-0kali1) UNRELEASED; urgency=low + + * New upstream release. + + -- Kali Janitor Tue, 27 Jul 2021 09:36:50 -0000 + python-webargs (7.0.1-0kali1) kali-dev; urgency=medium [ Kali Janitor ] diff --git a/docs/advanced.rst b/docs/advanced.rst index 264f56a..853fd64 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -110,7 +110,7 @@ @use_args(UserSchema()) def profile_view(args): - username = args["userame"] + username = args["username"] # ... @@ -152,8 +152,8 @@ +++++++++++++++++ By default, webargs will pass `unknown=marshmallow.EXCLUDE` except when the -location is `json`, `form`, `json_or_form`, `path`, or `path`. In those cases, -it uses `unknown=marshmallow.RAISE` instead. +location is `json`, `form`, `json_or_form`, or `path`. In those cases, it uses +`unknown=marshmallow.RAISE` instead. You can change these defaults by overriding `DEFAULT_UNKNOWN_BY_LOCATION`. This is a mapping of locations to values to pass. @@ -180,7 +180,7 @@ # so EXCLUDE will be used @app.route("/", methods=["GET"]) @parser.use_args({"foo": fields.Int()}, location="query") - def get(self, args): + def get(args): return f"foo x 2 = {args['foo'] * 2}" @@ -188,7 +188,7 @@ # so no value will be passed for `unknown` @app.route("/", methods=["POST"]) @parser.use_args({"foo": fields.Int(), "bar": fields.Int()}, location="json") - def post(self, args): + def post(args): return f"foo x bar = {args['foo'] * args['bar']}" @@ -205,7 +205,7 @@ # effect and `INCLUDE` will always be used @app.route("/", methods=["POST"]) @parser.use_args({"foo": fields.Int(), "bar": fields.Int()}, location="json") - def post(self, args): + def post(args): unexpected_args = [k for k in args.keys() if k not in ("foo", "bar")] return f"foo x bar = {args['foo'] * args['bar']}; unexpected args={unexpected_args}" @@ -237,7 +237,7 @@ # as a result, the schema's behavior (EXCLUDE) is used @app.route("/", methods=["POST"]) @use_args(RectangleSchema(), location="json", unknown=None) - def get(self, args): + def get(args): return f"area = {args['length'] * args['width']}" @@ -275,7 +275,7 @@ @use_args(RectangleSchema) - def post(self, rect: Rectangle): + def post(rect: Rectangle): return f"Area: {rect.length * rect.width}" Packages such as `marshmallow-sqlalchemy `_ and `marshmallow-dataclass `_ generate schemas that deserialize to non-dictionary objects. @@ -435,6 +435,50 @@ structure_dict_pair(r, k, v) return r +Parser pre_load +--------------- + +Similar to ``@pre_load`` decorated hooks on marshmallow Schemas, +:class:`Parser ` classes define a method, +`pre_load ` which can +be overridden to provide per-parser transformations of data. +The only way to make use of `pre_load ` is to +subclass a :class:`Parser ` and provide an +implementation. + +`pre_load ` is given the data fetched from a +location, the schema which will be used, the request object, and the location +name which was requested. For example, to define a ``FlaskParser`` which strips +whitespace from ``form`` and ``query`` data, one could write the following: + +.. code-block:: python + + from webargs.flaskparser import FlaskParser + import typing + + + def _strip_whitespace(value): + if isinstance(value, str): + value = value.strip() + elif isinstance(value, typing.Mapping): + return {k: _strip_whitespace(value[k]) for k in value} + elif isinstance(value, (list, tuple)): + return type(value)(map(_strip_whitespace, value)) + return value + + + class WhitspaceStrippingFlaskParser(FlaskParser): + def pre_load(self, location_data, *, schema, req, location): + if location in ("query", "form"): + return _strip_whitespace(location_data) + return location_data + +Note that `Parser.pre_load ` is run after location +loading but before ``Schema.load`` is called. It can therefore be called on +multiple types of mapping objects, including +:class:`MultiDictProxy `, depending on what the +location loader returns. + Returning HTTP 400 Responses ---------------------------- @@ -493,6 +537,92 @@ """ # ... +Multi-Field Detection +--------------------- + +If a ``List`` field is used to parse data from a location like query parameters -- +where one or multiple values can be passed for a single parameter name -- then +webargs will automatically treat that field as a list and parse multiple values +if present. + +To implement this behavior, webargs will examine schemas for ``marshmallow.fields.List`` +fields. ``List`` fields get unpacked to list values when data is loaded, and +other fields do not. This also applies to fields which inherit from ``List``. + +.. note:: + + In webargs v8, ``Tuple`` will be treated this way as well, in addition to ``List``. + +What if you have a list which should be treated as a "multi-field" but which +does not inherit from ``List``? webargs offers two solutions. +You can add the custom attribute `is_multiple=True` to your field or you +can add your class to your parser's list of `KNOWN_MULTI_FIELDS`. + +First, let's define a "multiplexing field" which takes a string or list of +strings to serve as an example: + +.. code-block:: python + + # a custom field class which can accept values like List(String()) or String() + class CustomMultiplexingField(fields.String): + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str): + return super()._deserialize(value, attr, data, **kwargs) + return [ + self._deserialize(v, attr, data, **kwargs) + for v in value + if isinstance(v, str) + ] + + def _serialize(self, value, attr, **kwargs): + if isinstance(value, str): + return super()._serialize(value, attr, **kwargs) + return [self._serialize(v, attr, **kwargs) for v in value if isinstance(v, str)] + + +If you control the definition of ``CustomMultiplexingField``, you can just add +``is_multiple=True`` to it: + +.. code-block:: python + + # option 1: define the field with is_multiple = True + from webargs.flaskparser import parser + + + class CustomMultiplexingField(fields.Field): + is_multiple = True # <----- this marks this as a multi-field + + ... # as above + +If you don't control the definition of ``CustomMultiplexingField``, for example +because it comes from a library, you can add it to the list of known +multifields: + +.. code-block:: python + + # option 2: add the field to the parer's list of multi-fields + class MyParser(FlaskParser): + KNOWN_MULTI_FIELDS = list(FlaskParser.KNOWN_MULTI_FIELDS) + [ + CustomMultiplexingField + ] + + + parser = MyParser() + +In either case, the end result is that you can use the multifield and it will +be detected as a list when unpacking query string data: + +.. code-block:: python + + # gracefully handles + # ...?foo=a + # ...?foo=a&foo=b + # and treats them as ["a"] and ["a", "b"] respectively + @parser.use_args({"foo": CustomMultiplexingField()}, location="query") + def show_foos(foo): + ... + + Mixing Locations ---------------- diff --git a/setup.py b/setup.py index 101d3a4..fb30a1f 100644 --- a/setup.py +++ b/setup.py @@ -20,12 +20,12 @@ ] + FRAMEWORKS, "lint": [ - "mypy==0.790", - "flake8==3.8.4", - "flake8-bugbear==20.11.1", + "mypy==0.812", + "flake8==3.9.0", + "flake8-bugbear==21.4.3", "pre-commit~=2.4", ], - "docs": ["Sphinx==3.3.1", "sphinx-issues==1.2.0", "sphinx-typlog-theme==0.8.0"] + "docs": ["Sphinx==3.5.3", "sphinx-issues==1.2.0", "sphinx-typlog-theme==0.8.0"] + FRAMEWORKS, } EXTRAS_REQUIRE["dev"] = EXTRAS_REQUIRE["tests"] + EXTRAS_REQUIRE["lint"] + ["tox"] diff --git a/src/webargs/__init__.py b/src/webargs/__init__.py index efadffc..6b53e24 100755 --- a/src/webargs/__init__.py +++ b/src/webargs/__init__.py @@ -7,6 +7,6 @@ from webargs.core import ValidationError from webargs import fields -__version__ = "7.0.1" +__version__ = "8.0.0" __version_info__ = tuple(LooseVersion(__version__).version) __all__ = ("ValidationError", "fields", "missing", "validate") diff --git a/src/webargs/aiohttpparser.py b/src/webargs/aiohttpparser.py index 9478026..b85ff51 100644 --- a/src/webargs/aiohttpparser.py +++ b/src/webargs/aiohttpparser.py @@ -71,7 +71,7 @@ class AIOHTTPParser(AsyncParser): """aiohttp request argument parser.""" - DEFAULT_UNKNOWN_BY_LOCATION = { + DEFAULT_UNKNOWN_BY_LOCATION: typing.Dict[str, typing.Optional[str]] = { "match_info": RAISE, "path": RAISE, **core.Parser.DEFAULT_UNKNOWN_BY_LOCATION, @@ -84,12 +84,12 @@ def load_querystring(self, req, schema: Schema) -> MultiDictProxy: """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.query, schema) + return self._makeproxy(req.query, schema) async def load_form(self, req, schema: Schema) -> MultiDictProxy: """Return form values from the request as a MultiDictProxy.""" post_data = await req.post() - return MultiDictProxy(post_data, schema) + return self._makeproxy(post_data, schema) async def load_json_or_form( self, req, schema: Schema @@ -114,11 +114,11 @@ def load_headers(self, req, schema: Schema) -> MultiDictProxy: """Return headers from the request as a MultiDictProxy.""" - return MultiDictProxy(req.headers, schema) + return self._makeproxy(req.headers, schema) def load_cookies(self, req, schema: Schema) -> MultiDictProxy: """Return cookies from the request as a MultiDictProxy.""" - return MultiDictProxy(req.cookies, schema) + return self._makeproxy(req.cookies, schema) def load_files(self, req, schema: Schema) -> typing.NoReturn: raise NotImplementedError( diff --git a/src/webargs/bottleparser.py b/src/webargs/bottleparser.py index 3cfd299..dcf2273 100644 --- a/src/webargs/bottleparser.py +++ b/src/webargs/bottleparser.py @@ -19,7 +19,6 @@ import bottle from webargs import core -from webargs.multidictproxy import MultiDictProxy class BottleParser(core.Parser): @@ -49,7 +48,7 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.query, schema) + return self._makeproxy(req.query, schema) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy.""" @@ -58,11 +57,11 @@ # TODO: Make this check more specific if core.is_json(req.content_type): return core.missing - return MultiDictProxy(req.forms, schema) + return self._makeproxy(req.forms, schema) def load_headers(self, req, schema): """Return headers from the request as a MultiDictProxy.""" - return MultiDictProxy(req.headers, schema) + return self._makeproxy(req.headers, schema) def load_cookies(self, req, schema): """Return cookies from the request.""" @@ -70,7 +69,7 @@ def load_files(self, req, schema): """Return files from the request as a MultiDictProxy.""" - return MultiDictProxy(req.files, schema) + return self._makeproxy(req.files, schema) def handle_error(self, error, req, schema, *, error_status_code, error_headers): """Handles errors during parsing. Aborts the current request with a diff --git a/src/webargs/core.py b/src/webargs/core.py index 25080ee..e675d77 100644 --- a/src/webargs/core.py +++ b/src/webargs/core.py @@ -8,14 +8,13 @@ from marshmallow import ValidationError from marshmallow.utils import missing -from webargs.fields import DelimitedList +from webargs.multidictproxy import MultiDictProxy logger = logging.getLogger(__name__) __all__ = [ "ValidationError", - "is_multiple", "Parser", "missing", "parse_json", @@ -55,11 +54,6 @@ if obj and not _iscallable(obj): raise ValueError(f"{obj!r} is not callable.") return obj - - -def is_multiple(field: ma.fields.Field) -> bool: - """Return whether or not `field` handles repeated/multi-value arguments.""" - return isinstance(field, ma.fields.List) and not isinstance(field, DelimitedList) def get_mimetype(content_type: str) -> str: @@ -132,10 +126,10 @@ DEFAULT_LOCATION: str = "json" #: Default value to use for 'unknown' on schema load # on a per-location basis - DEFAULT_UNKNOWN_BY_LOCATION: typing.Dict[str, str] = { - "json": ma.RAISE, - "form": ma.RAISE, - "json_or_form": ma.RAISE, + DEFAULT_UNKNOWN_BY_LOCATION: typing.Dict[str, typing.Optional[str]] = { + "json": None, + "form": None, + "json_or_form": None, "querystring": ma.EXCLUDE, "query": ma.EXCLUDE, "headers": ma.EXCLUDE, @@ -148,6 +142,8 @@ DEFAULT_VALIDATION_STATUS: int = DEFAULT_VALIDATION_STATUS #: Default error message for validation errors DEFAULT_VALIDATION_MESSAGE: str = "Invalid value." + #: field types which should always be treated as if they set `is_multiple=True` + KNOWN_MULTI_FIELDS: typing.List[typing.Type] = [ma.fields.List, ma.fields.Tuple] #: Maps location => method name __location_map__: typing.Dict[str, typing.Union[str, typing.Callable]] = { @@ -175,6 +171,12 @@ ) self.schema_class = schema_class or self.DEFAULT_SCHEMA_CLASS self.unknown = unknown + + def _makeproxy( + self, multidict, schema: ma.Schema, cls: typing.Type = MultiDictProxy + ): + """Create a multidict proxy object with options from the current parser""" + return cls(multidict, schema, known_multi_fields=tuple(self.KNOWN_MULTI_FIELDS)) def _get_loader(self, location: str) -> typing.Callable: """Get the loader function for the given location. @@ -320,7 +322,10 @@ location_data = self._load_location_data( schema=schema, req=req, location=location ) - data = schema.load(location_data, **load_kwargs) + preprocessed_data = self.pre_load( + location_data, schema=schema, req=req, location=location + ) + data = schema.load(preprocessed_data, **load_kwargs) self._validate_arguments(data, validators) except ma.exceptions.ValidationError as error: self._on_validation_error( @@ -521,6 +526,15 @@ self.error_callback = func return func + def pre_load( + self, location_data: Mapping, *, schema: ma.Schema, req: Request, location: str + ) -> Mapping: + """A method of the parser which can transform data after location + loading is done. By default it does nothing, but users can subclass + parsers and override this method. + """ + return location_data + def _handle_invalid_json_error( self, error: typing.Union[json.JSONDecodeError, UnicodeDecodeError], diff --git a/src/webargs/djangoparser.py b/src/webargs/djangoparser.py index 73c135c..cef5c28 100644 --- a/src/webargs/djangoparser.py +++ b/src/webargs/djangoparser.py @@ -18,7 +18,6 @@ return HttpResponse('Hello ' + args['name']) """ from webargs import core -from webargs.multidictproxy import MultiDictProxy def is_json_request(req): @@ -48,11 +47,11 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.GET, schema) + return self._makeproxy(req.GET, schema) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy.""" - return MultiDictProxy(req.POST, schema) + return self._makeproxy(req.POST, schema) def load_cookies(self, req, schema): """Return cookies from the request.""" @@ -66,7 +65,7 @@ def load_files(self, req, schema): """Return files from the request as a MultiDictProxy.""" - return MultiDictProxy(req.FILES, schema) + return self._makeproxy(req.FILES, schema) def get_request_from_view_args(self, view, args, kwargs): # The first argument is either `self` or `request` diff --git a/src/webargs/falconparser.py b/src/webargs/falconparser.py index d2eb448..cfe1170 100644 --- a/src/webargs/falconparser.py +++ b/src/webargs/falconparser.py @@ -6,7 +6,6 @@ import marshmallow as ma from webargs import core -from webargs.multidictproxy import MultiDictProxy HTTP_422 = "422 Unprocessable Entity" @@ -97,7 +96,7 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.params, schema) + return self._makeproxy(req.params, schema) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy @@ -109,7 +108,7 @@ form = parse_form_body(req) if form is core.missing: return form - return MultiDictProxy(form, schema) + return self._makeproxy(form, schema) def load_media(self, req, schema): """Return data unpacked and parsed by one of Falcon's media handlers. diff --git a/src/webargs/fields.py b/src/webargs/fields.py index 806cc5d..f8991d1 100644 --- a/src/webargs/fields.py +++ b/src/webargs/fields.py @@ -55,6 +55,8 @@ """ delimiter: str = "," + # delimited fields set is_multiple=False for webargs.core.is_multiple + is_multiple: bool = False def _serialize(self, value, attr, obj, **kwargs): # serializing will start with parent-class serialization, so that we correctly diff --git a/src/webargs/flaskparser.py b/src/webargs/flaskparser.py index 4fbe15c..053988d 100644 --- a/src/webargs/flaskparser.py +++ b/src/webargs/flaskparser.py @@ -20,13 +20,14 @@ uid=uid, per_page=args["per_page"] ) """ +import typing + import flask from werkzeug.exceptions import HTTPException import marshmallow as ma from webargs import core -from webargs.multidictproxy import MultiDictProxy def abort(http_status_code, exc=None, **kwargs): @@ -50,7 +51,7 @@ class FlaskParser(core.Parser): """Flask request argument parser.""" - DEFAULT_UNKNOWN_BY_LOCATION = { + DEFAULT_UNKNOWN_BY_LOCATION: typing.Dict[str, typing.Optional[str]] = { "view_args": ma.RAISE, "path": ma.RAISE, **core.Parser.DEFAULT_UNKNOWN_BY_LOCATION, @@ -80,15 +81,15 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.args, schema) + return self._makeproxy(req.args, schema) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy.""" - return MultiDictProxy(req.form, schema) + return self._makeproxy(req.form, schema) def load_headers(self, req, schema): """Return headers from the request as a MultiDictProxy.""" - return MultiDictProxy(req.headers, schema) + return self._makeproxy(req.headers, schema) def load_cookies(self, req, schema): """Return cookies from the request.""" @@ -96,7 +97,7 @@ def load_files(self, req, schema): """Return files from the request as a MultiDictProxy.""" - return MultiDictProxy(req.files, schema) + return self._makeproxy(req.files, schema) def handle_error(self, error, req, schema, *, error_status_code, error_headers): """Handles errors during parsing. Aborts the current HTTP request and diff --git a/src/webargs/multidictproxy.py b/src/webargs/multidictproxy.py index 19792dc..a277178 100644 --- a/src/webargs/multidictproxy.py +++ b/src/webargs/multidictproxy.py @@ -1,8 +1,7 @@ from collections.abc import Mapping +import typing import marshmallow as ma - -from webargs.core import missing, is_multiple class MultiDictProxy(Mapping): @@ -15,22 +14,39 @@ In all other cases, __getitem__ proxies directly to the input multidict. """ - def __init__(self, multidict, schema: ma.Schema): + def __init__( + self, + multidict, + schema: ma.Schema, + known_multi_fields: typing.Tuple[typing.Type, ...] = ( + ma.fields.List, + ma.fields.Tuple, + ), + ): self.data = multidict + self.known_multi_fields = known_multi_fields self.multiple_keys = self._collect_multiple_keys(schema) - @staticmethod - def _collect_multiple_keys(schema: ma.Schema): + def _is_multiple(self, field: ma.fields.Field) -> bool: + """Return whether or not `field` handles repeated/multi-value arguments.""" + # fields which set `is_multiple = True/False` will have the value selected, + # otherwise, we check for explicit criteria + is_multiple_attr = getattr(field, "is_multiple", None) + if is_multiple_attr is not None: + return is_multiple_attr + return isinstance(field, self.known_multi_fields) + + def _collect_multiple_keys(self, schema: ma.Schema): result = set() for name, field in schema.fields.items(): - if not is_multiple(field): + if not self._is_multiple(field): continue result.add(field.data_key if field.data_key is not None else name) return result def __getitem__(self, key): - val = self.data.get(key, missing) - if val is missing or key not in self.multiple_keys: + val = self.data.get(key, ma.missing) + if val is ma.missing or key not in self.multiple_keys: return val if hasattr(self.data, "getlist"): return self.data.getlist(key) diff --git a/src/webargs/pyramidparser.py b/src/webargs/pyramidparser.py index 9537fb9..4be2884 100644 --- a/src/webargs/pyramidparser.py +++ b/src/webargs/pyramidparser.py @@ -25,6 +25,7 @@ server.serve_forever() """ import functools +import typing from collections.abc import Mapping from webob.multidict import MultiDict @@ -34,7 +35,6 @@ from webargs import core from webargs.core import json -from webargs.multidictproxy import MultiDictProxy def is_json_request(req): @@ -44,7 +44,7 @@ class PyramidParser(core.Parser): """Pyramid request argument parser.""" - DEFAULT_UNKNOWN_BY_LOCATION = { + DEFAULT_UNKNOWN_BY_LOCATION: typing.Dict[str, typing.Optional[str]] = { "matchdict": ma.RAISE, "path": ma.RAISE, **core.Parser.DEFAULT_UNKNOWN_BY_LOCATION, @@ -67,28 +67,28 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return MultiDictProxy(req.GET, schema) + return self._makeproxy(req.GET, schema) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy.""" - return MultiDictProxy(req.POST, schema) + return self._makeproxy(req.POST, schema) def load_cookies(self, req, schema): """Return cookies from the request as a MultiDictProxy.""" - return MultiDictProxy(req.cookies, schema) + return self._makeproxy(req.cookies, schema) def load_headers(self, req, schema): """Return headers from the request as a MultiDictProxy.""" - return MultiDictProxy(req.headers, schema) + return self._makeproxy(req.headers, schema) def load_files(self, req, schema): """Return files from the request as a MultiDictProxy.""" files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file")) - return MultiDictProxy(MultiDict(files), schema) + return self._makeproxy(MultiDict(files), schema) def load_matchdict(self, req, schema): """Return the request's ``matchdict`` as a MultiDictProxy.""" - return MultiDictProxy(req.matchdict, schema) + return self._makeproxy(req.matchdict, schema) def handle_error(self, error, req, schema, *, error_status_code, error_headers): """Handles errors during parsing. Aborts the current HTTP request and diff --git a/src/webargs/tornadoparser.py b/src/webargs/tornadoparser.py index 4c919a0..c5c53cd 100644 --- a/src/webargs/tornadoparser.py +++ b/src/webargs/tornadoparser.py @@ -97,25 +97,31 @@ def load_querystring(self, req, schema): """Return query params from the request as a MultiDictProxy.""" - return WebArgsTornadoMultiDictProxy(req.query_arguments, schema) + return self._makeproxy( + req.query_arguments, schema, cls=WebArgsTornadoMultiDictProxy + ) def load_form(self, req, schema): """Return form values from the request as a MultiDictProxy.""" - return WebArgsTornadoMultiDictProxy(req.body_arguments, schema) + return self._makeproxy( + req.body_arguments, schema, cls=WebArgsTornadoMultiDictProxy + ) def load_headers(self, req, schema): """Return headers from the request as a MultiDictProxy.""" - return WebArgsTornadoMultiDictProxy(req.headers, schema) + return self._makeproxy(req.headers, schema, cls=WebArgsTornadoMultiDictProxy) def load_cookies(self, req, schema): """Return cookies from the request as a MultiDictProxy.""" # use the specialized subclass specifically for handling Tornado # cookies - return WebArgsTornadoCookiesMultiDictProxy(req.cookies, schema) + return self._makeproxy( + req.cookies, schema, cls=WebArgsTornadoCookiesMultiDictProxy + ) def load_files(self, req, schema): """Return files from the request as a MultiDictProxy.""" - return WebArgsTornadoMultiDictProxy(req.files, schema) + return self._makeproxy(req.files, schema, cls=WebArgsTornadoMultiDictProxy) def handle_error(self, error, req, schema, *, error_status_code, error_headers): """Handles errors during parsing. Raises a `tornado.web.HTTPError` diff --git a/tests/test_core.py b/tests/test_core.py index 829a3fe..48b8ca6 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,4 +1,5 @@ import datetime +import typing from unittest import mock import pytest @@ -35,7 +36,10 @@ """A minimal parser implementation that parses mock requests.""" def load_querystring(self, req, schema): - return MultiDictProxy(req.query, schema) + return self._makeproxy(req.query, schema) + + def load_form(self, req, schema): + return MultiDictProxy(req.form, schema) def load_json(self, req, schema): return req.json @@ -1032,6 +1036,96 @@ parser.parse(args, web_request) +@pytest.mark.parametrize("input_dict", multidicts) +@pytest.mark.parametrize( + "setting", + [ + "is_multiple_true", + "is_multiple_false", + "is_multiple_notset", + "list_field", + "tuple_field", + "added_to_known", + ], +) +def test_is_multiple_detection(web_request, parser, input_dict, setting): + # this custom class "multiplexes" in that it can be given a single value or + # list of values -- a single value is treated as a string, and a list of + # values is treated as a list of strings + class CustomMultiplexingField(fields.String): + def _deserialize(self, value, attr, data, **kwargs): + if isinstance(value, str): + return super()._deserialize(value, attr, data, **kwargs) + return [ + self._deserialize(v, attr, data, **kwargs) + for v in value + if isinstance(v, str) + ] + + def _serialize(self, value, attr, **kwargs): + if isinstance(value, str): + return super()._serialize(value, attr, **kwargs) + return [ + self._serialize(v, attr, **kwargs) for v in value if isinstance(v, str) + ] + + class CustomMultipleField(CustomMultiplexingField): + is_multiple = True + + class CustomNonMultipleField(CustomMultiplexingField): + is_multiple = False + + # the request's query params are the input multidict + web_request.query = input_dict + + # case 1: is_multiple=True + if setting == "is_multiple_true": + # the multidict should unpack to a list of strings + # + # order is not necessarily guaranteed by the multidict implementations, but + # both values must be present + args = {"foos": CustomMultipleField()} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in (["a", "b"], ["b", "a"]) + # case 2: is_multiple=False + elif setting == "is_multiple_false": + # the multidict should unpack to a string + # + # either value may be returned, depending on the multidict implementation, + # but not both + args = {"foos": CustomNonMultipleField()} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in ("a", "b") + # case 3: is_multiple is not set + elif setting == "is_multiple_notset": + # this should be the same as is_multiple=False + args = {"foos": CustomMultiplexingField()} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in ("a", "b") + # case 4: the field is a List (special case) + elif setting == "list_field": + # this should behave like the is_multiple=True case + args = {"foos": fields.List(fields.Str())} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in (["a", "b"], ["b", "a"]) + # case 5: the field is a Tuple (special case) + elif setting == "tuple_field": + # this should behave like the is_multiple=True case and produce a tuple + args = {"foos": fields.Tuple((fields.Str, fields.Str))} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in (("a", "b"), ("b", "a")) + # case 6: the field is custom, but added to the known fields of the parser + elif setting == "added_to_known": + # if it's included in the known multifields and is_multiple is not set, behave + # like is_multiple=True + parser.KNOWN_MULTI_FIELDS.append(CustomMultiplexingField) + args = {"foos": CustomMultiplexingField()} + result = parser.parse(args, web_request, location="query") + assert result["foos"] in (["a", "b"], ["b", "a"]) + else: + raise NotImplementedError + + def test_validation_errors_in_validator_are_passed_to_handle_error(parser, web_request): def validate(value): raise ValidationError("Something went wrong.") @@ -1134,3 +1228,84 @@ p = CustomParser() ret = p.parse(argmap, web_request) assert ret == {"value": "hello world"} + + +def test_parser_pre_load(web_request): + class CustomParser(MockRequestParser): + # pre-load hook to strip whitespace from query params + def pre_load(self, data, *, schema, req, location): + if location == "query": + return {k: v.strip() for k, v in data.items()} + return data + + parser = CustomParser() + + # mock data for both query and json + web_request.query = web_request.json = {"value": " hello "} + argmap = {"value": fields.Str()} + + # data gets through for 'json' just fine + ret = parser.parse(argmap, web_request) + assert ret == {"value": " hello "} + + # but for 'query', the pre_load hook changes things + ret = parser.parse(argmap, web_request, location="query") + assert ret == {"value": "hello"} + + +# this test is meant to be a run of the WhitspaceStrippingFlaskParser we give +# in the docs/advanced.rst examples for how to use pre_load +# this helps ensure that the example code is correct +# rather than a FlaskParser, we're working with the mock parser, but it's +# otherwise the same +def test_whitespace_stripping_parser_example(web_request): + def _strip_whitespace(value): + if isinstance(value, str): + value = value.strip() + elif isinstance(value, typing.Mapping): + return {k: _strip_whitespace(value[k]) for k in value} + elif isinstance(value, (list, tuple)): + return type(value)(map(_strip_whitespace, value)) + return value + + class WhitspaceStrippingParser(MockRequestParser): + def pre_load(self, location_data, *, schema, req, location): + if location in ("query", "form"): + ret = _strip_whitespace(location_data) + return ret + return location_data + + parser = WhitspaceStrippingParser() + + # mock data for query, form, and json + web_request.form = web_request.query = web_request.json = {"value": " hello "} + argmap = {"value": fields.Str()} + + # data gets through for 'json' just fine + ret = parser.parse(argmap, web_request) + assert ret == {"value": " hello "} + + # but for 'query' and 'form', the pre_load hook changes things + for loc in ("query", "form"): + ret = parser.parse(argmap, web_request, location=loc) + assert ret == {"value": "hello"} + + # check that it applies in the case where the field is a list type + # applied to an argument (logic for `tuple` is effectively the same) + web_request.form = web_request.query = web_request.json = { + "ids": [" 1", "3", " 4"], + "values": [" foo ", " bar"], + } + schema = Schema.from_dict( + {"ids": fields.List(fields.Int), "values": fields.List(fields.Str)} + ) + for loc in ("query", "form"): + ret = parser.parse(schema, web_request, location=loc) + assert ret == {"ids": [1, 3, 4], "values": ["foo", "bar"]} + + # json loading should also work even though the pre_load hook above + # doesn't strip whitespace from JSON data + # - values=[" foo ", ...] will have whitespace preserved + # - ids=[" 1", ...] will still parse okay because " 1" is valid for fields.Int + ret = parser.parse(schema, web_request, location="json") + assert ret == {"ids": [1, 3, 4], "values": [" foo ", " bar"]} diff --git a/tox.ini b/tox.ini index f3915a2..f18520b 100644 --- a/tox.ini +++ b/tox.ini @@ -30,6 +30,7 @@ # issues in which `mypy` running on every file standalone won't catch things [testenv:mypy] deps = mypy +extras = frameworks commands = mypy src/ [testenv:docs]