Use functions to implement validators

This commit is contained in:
Richard Chien 2019-01-25 12:57:28 +08:00
parent f8ecc7bba1
commit 0079cd1876

View File

@ -1,7 +1,7 @@
import re import re
from typing import Callable, Any from typing import Callable, Any
from nonebot.command.argfilter import ValidateError from nonebot.command.argfilter import ValidateError, ArgFilter_T
class BaseValidator: class BaseValidator:
@ -12,90 +12,88 @@ class BaseValidator:
raise ValidateError(self.message) raise ValidateError(self.message)
class not_empty(BaseValidator): def _raise_failure(message):
raise ValidateError(message)
def not_empty(message=None) -> ArgFilter_T:
""" """
Validate any object to ensure it's not empty (is None or has no elements). Validate any object to ensure it's not empty (is None or has no elements).
""" """
def __call__(self, value): def validate(value):
if value is None: if value is None:
self.raise_failure() _raise_failure(message)
if hasattr(value, '__len__') and value.__len__() == 0: if hasattr(value, '__len__') and value.__len__() == 0:
self.raise_failure() _raise_failure(message)
return value return value
return validate
class fit_size(BaseValidator):
def fit_size(min_length: int = 0, max_length: int = None,
message=None) -> ArgFilter_T:
""" """
Validate any sized object to ensure the size/length Validate any sized object to ensure the size/length
is in a given range [min_length, max_length]. is in a given range [min_length, max_length].
""" """
def __init__(self, min_length: int = 0, max_length: int = None, def validate(value):
message=None):
super().__init__(message)
self.min_length = min_length
self.max_length = max_length
def __call__(self, value):
length = len(value) if value is not None else 0 length = len(value) if value is not None else 0
if length < self.min_length or \ if length < min_length or \
(self.max_length is not None and length > self.max_length): (max_length is not None and length > max_length):
self.raise_failure() _raise_failure(message)
return value return value
return validate
class match_regex(BaseValidator):
def match_regex(pattern: str, message=None, *, flags=0,
fullmatch: bool = False) -> ArgFilter_T:
""" """
Validate any string object to ensure it matches a given pattern. Validate any string object to ensure it matches a given pattern.
""" """
def __init__(self, pattern: str, message=None, *, flags=0, pattern = re.compile(pattern, flags)
fullmatch: bool = False):
super().__init__(message)
self.pattern = re.compile(pattern, flags)
self.fullmatch = fullmatch
def __call__(self, value): def validate(value):
if self.fullmatch: if fullmatch:
if not re.fullmatch(self.pattern, value): if not re.fullmatch(pattern, value):
self.raise_failure() _raise_failure(message)
else: else:
if not re.match(self.pattern, value): if not re.match(pattern, value):
self.raise_failure() _raise_failure(message)
return value return value
return validate
class ensure_true(BaseValidator):
def ensure_true(bool_func: Callable[[Any], bool],
message=None) -> ArgFilter_T:
""" """
Validate any object to ensure the result of applying Validate any object to ensure the result of applying
a boolean function to it is True. a boolean function to it is True.
""" """
def __init__(self, bool_func: Callable[[Any], bool], message=None): def validate(value):
super().__init__(message) if bool_func(value) is not True:
self.bool_func = bool_func _raise_failure(message)
def __call__(self, value):
if self.bool_func(value) is not True:
self.raise_failure()
return value return value
return validate
class between_inclusive(BaseValidator):
def between_inclusive(start=None, end=None, message=None) -> ArgFilter_T:
""" """
Validate any comparable object to ensure it's between Validate any comparable object to ensure it's between
`start` and `end` inclusively. `start` and `end` inclusively.
""" """
def __init__(self, start=None, end=None, message=None): def validate(value):
super().__init__(message) if start is not None and value < start:
self.start = start _raise_failure(message)
self.end = end if end is not None and end < value:
_raise_failure(message)
def __call__(self, value):
if self.start is not None and value < self.start:
self.raise_failure()
if self.end is not None and self.end < value:
self.raise_failure()
return value return value
return validate