Codebase list python-webargs / 5d50e68 webargs / fields.py
5d50e68

Tree @5d50e68 (Download .tar.gz)

fields.py @5d50e68raw · history · blame

# -*- coding: utf-8 -*-
"""Field classes.

Includes all fields from `marshmallow.fields` in addition to a custom
`Nested` field and `DelimitedList`.

All fields can optionally take a special `location` keyword argument, which
tells webargs where to parse the request argument from. ::

    args = {
        'active': fields.Bool(location='query'),
        'content_type': fields.Str(data_key='Content-Type',
                                   location='headers')
    }

Note: `data_key` replaced `load_from` in marshmallow 3.
When using marshmallow 2, use `load_from`.
"""
import marshmallow as ma

from webargs.core import dict2schema

__all__ = ["Nested", "DelimitedList"]
# Expose all fields from marshmallow.fields.
# We do this instead of 'from marshmallow.fields import *' because webargs
# has its own subclass of Nested, which must not be shadowed by marshmallow's.
# A generator expression is used (rather than a list comprehension or a plain
# for-loop) so that no iteration variable leaks into this module's namespace,
# including on Python 2 where list-comprehension variables leak.
_ma_field_names = tuple(
    field_name for field_name in ma.fields.__all__ if field_name != "Nested"
)
__all__.extend(_ma_field_names)
globals().update(
    (field_name, getattr(ma.fields, field_name)) for field_name in _ma_field_names
)
# Remove the temporary helper so it does not become a module attribute.
del _ma_field_names


class Nested(ma.fields.Nested):
    """Same as `marshmallow.fields.Nested`, except can be passed a dictionary as
    the first argument, which will be converted to a `marshmallow.Schema`.

    :param nested: A schema (class, instance, or whatever marshmallow's
        `Nested` accepts) or a ``dict`` mapping field names to fields; a
        ``dict`` is turned into a schema via `dict2schema`.
    """

    def __init__(self, nested, *args, **kwargs):
        # Only plain dicts are converted; anything else is handed through as-is.
        schema = dict2schema(nested) if isinstance(nested, dict) else nested
        super(Nested, self).__init__(schema, *args, **kwargs)


class DelimitedList(ma.fields.List):
    """Same as `marshmallow.fields.List`, except can load from either a list or
    a delimited string (e.g. "foo,bar,baz").

    :param Field cls_or_instance: A field class or instance.
    :param str delimiter: Delimiter between values. A falsy value (``None`` or
        ``""``) falls back to the class-level default, ``","``.
    :param bool as_string: Dump values to string.
    """

    delimiter = ","

    def __init__(self, cls_or_instance, delimiter=None, as_string=False, **kwargs):
        # A falsy ``delimiter`` keeps the class attribute default.
        self.delimiter = delimiter or self.delimiter
        self.as_string = as_string
        super(DelimitedList, self).__init__(cls_or_instance, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        """Serialize ``value`` as a list, or as a single delimited string when
        ``as_string`` is set.

        ``**kwargs`` are forwarded to the parent implementation so the field
        works with marshmallow 3, which passes extra keyword arguments here.
        """
        ret = super(DelimitedList, self)._serialize(value, attr, obj, **kwargs)
        # The parent returns None for a None value; there is nothing to join.
        if ret is None or not self.as_string:
            return ret
        return self.delimiter.join(format(each) for each in ret)

    def _deserialize(self, value, attr, data, **kwargs):
        """Deserialize ``value`` from either a non-string iterable or a string
        split on ``self.delimiter``.

        Calls ``self.fail("invalid")`` when ``value`` is neither a non-string
        iterable nor an object with a ``split`` method. ``**kwargs`` are
        forwarded to the parent implementation (marshmallow 3 compatibility).
        """
        if ma.utils.is_iterable_but_not_string(value):
            ret = value
        else:
            try:
                ret = value.split(self.delimiter)
            except AttributeError:
                # e.g. an int or None: not a list and not splittable.
                self.fail("invalid")
        return super(DelimitedList, self)._deserialize(ret, attr, data, **kwargs)