Use functions to implement validators

This commit is contained in:
Richard Chien 2019-01-25 12:57:28 +08:00
parent f8ecc7bba1
commit 0079cd1876

View File

@ -1,7 +1,7 @@
import re
from typing import Callable, Any
from nonebot.command.argfilter import ValidateError
from nonebot.command.argfilter import ValidateError, ArgFilter_T
class BaseValidator:
@ -12,90 +12,88 @@ class BaseValidator:
raise ValidateError(self.message)
class not_empty(BaseValidator):
def _raise_failure(message):
raise ValidateError(message)
def not_empty(message=None) -> ArgFilter_T:
"""
Validate any object to ensure it's not empty (is None or has no elements).
"""
def __call__(self, value):
def validate(value):
if value is None:
self.raise_failure()
_raise_failure(message)
if hasattr(value, '__len__') and value.__len__() == 0:
self.raise_failure()
_raise_failure(message)
return value
return validate
class fit_size(BaseValidator):
def fit_size(min_length: int = 0, max_length: int = None,
message=None) -> ArgFilter_T:
"""
Validate any sized object to ensure the size/length
is in a given range [min_length, max_length].
"""
def __init__(self, min_length: int = 0, max_length: int = None,
message=None):
super().__init__(message)
self.min_length = min_length
self.max_length = max_length
def __call__(self, value):
def validate(value):
length = len(value) if value is not None else 0
if length < self.min_length or \
(self.max_length is not None and length > self.max_length):
self.raise_failure()
if length < min_length or \
(max_length is not None and length > max_length):
_raise_failure(message)
return value
return validate
class match_regex(BaseValidator):
def match_regex(pattern: str, message=None, *, flags=0,
fullmatch: bool = False) -> ArgFilter_T:
"""
Validate any string object to ensure it matches a given pattern.
"""
def __init__(self, pattern: str, message=None, *, flags=0,
fullmatch: bool = False):
super().__init__(message)
self.pattern = re.compile(pattern, flags)
self.fullmatch = fullmatch
pattern = re.compile(pattern, flags)
def __call__(self, value):
if self.fullmatch:
if not re.fullmatch(self.pattern, value):
self.raise_failure()
def validate(value):
if fullmatch:
if not re.fullmatch(pattern, value):
_raise_failure(message)
else:
if not re.match(self.pattern, value):
self.raise_failure()
if not re.match(pattern, value):
_raise_failure(message)
return value
return validate
class ensure_true(BaseValidator):
def ensure_true(bool_func: Callable[[Any], bool],
message=None) -> ArgFilter_T:
"""
Validate any object to ensure the result of applying
a boolean function to it is True.
"""
def __init__(self, bool_func: Callable[[Any], bool], message=None):
super().__init__(message)
self.bool_func = bool_func
def __call__(self, value):
if self.bool_func(value) is not True:
self.raise_failure()
def validate(value):
if bool_func(value) is not True:
_raise_failure(message)
return value
return validate
class between_inclusive(BaseValidator):
def between_inclusive(start=None, end=None, message=None) -> ArgFilter_T:
"""
Validate any comparable object to ensure it's between
`start` and `end` inclusively.
"""
def __init__(self, start=None, end=None, message=None):
super().__init__(message)
self.start = start
self.end = end
def __call__(self, value):
if self.start is not None and value < self.start:
self.raise_failure()
if self.end is not None and self.end < value:
self.raise_failure()
def validate(value):
if start is not None and value < start:
_raise_failure(message)
if end is not None and end < value:
_raise_failure(message)
return value
return validate