新增函数柯里化功能

This commit is contained in:
远野千束 2024-08-29 18:52:19 +08:00
parent 807a240c49
commit f136c1a9fc

View File

@ -49,7 +49,7 @@ def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...],
""" """
求N元函数一阶偏导函数这玩意不太稳定慎用 求N元函数一阶偏导函数这玩意不太稳定慎用
> [!warning] > [!warning]
> 目前数学界对于数值微分的稳定性问题还没有很好的解决方案因此这个函数的稳定性也不是很好 > 目前数学界对于一个函数的导函数并没有通解的说法因此该函数的稳定性有待提升
Args: Args:
func: 函数 func: 函数
@ -76,6 +76,7 @@ def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...],
args_list_minus = list(args) args_list_minus = list(args)
args_list_minus[var] -= epsilon args_list_minus[var] -= epsilon
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon) return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)
return partial_derivative_func return partial_derivative_func
elif isinstance(var, tuple): elif isinstance(var, tuple):
def high_order_partial_derivative_func(*args: Var) -> Var: def high_order_partial_derivative_func(*args: Var) -> Var:
@ -84,6 +85,24 @@ def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...],
for v in var: for v in var:
result_func = get_partial_derivative_func(result_func, v, epsilon) result_func = get_partial_derivative_func(result_func, v, epsilon)
return result_func(*args) return result_func(*args)
return high_order_partial_derivative_func return high_order_partial_derivative_func
else: else:
raise ValueError("Invalid var type") raise ValueError("Invalid var type")
def curry(func: MultiVarsFunc, *args: Var) -> OneVarFunc:
"""
对多参数函数进行柯里化
> [!tip]
> 有关函数柯里化可参考[函数式编程--柯理化Currying](https://zhuanlan.zhihu.com/p/355859667)
Args:
func: 函数
*args: 参数
Returns:
柯里化后的函数
"""
def curried_func(*args2: Var) -> Var:
"""@litedoc-hide"""
return func(*args, *args2)
return curried_func