add partial derivative

This commit is contained in:
远野千束 2024-08-26 00:11:02 +08:00
parent cc06c34967
commit d9c41cd311
3 changed files with 15 additions and 13 deletions

View File

@ -9,7 +9,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Software: PyCharm @Software: PyCharm
""" """
from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarFunc, Number from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarsFunc, Number
from mbcp.mp_math.point import Point3 from mbcp.mp_math.point import Point3
from mbcp.mp_math.const import EPSILON from mbcp.mp_math.const import EPSILON
@ -44,9 +44,9 @@ class CurveEquation:
return "CurveEquation()" return "CurveEquation()"
def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarFunc: def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarsFunc:
""" """
求N元函数偏导函数 求N元函数一阶偏导函数
Args: Args:
func: 函数 func: 函数
var: 变量位置可为整数(一阶偏导)或整数元组(高阶偏导) var: 变量位置可为整数(一阶偏导)或整数元组(高阶偏导)
@ -64,7 +64,9 @@ def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...],
return partial_derivative_func return partial_derivative_func
elif isinstance(var, tuple): elif isinstance(var, tuple):
for i in var: for i in var:
print("测试第i个变量", i)
func = get_partial_derivative_func(func, i, epsilon) func = get_partial_derivative_func(func, i, epsilon)
print("测试第i个变量的偏导", func(1, 2))
return func return func
else: else:
raise ValueError("Invalid var type") raise ValueError("Invalid var type")

View File

@ -20,14 +20,14 @@ OneSingleVarFunc: TypeAlias = Callable[[SingleVar], SingleVar]
OneArrayFunc: TypeAlias = Callable[[ArrayVar], ArrayVar] OneArrayFunc: TypeAlias = Callable[[ArrayVar], ArrayVar]
OneVarFunc: TypeAlias = OneSingleVarFunc | OneArrayFunc OneVarFunc: TypeAlias = OneSingleVarFunc | OneArrayFunc
TwoSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar] TwoSingleVarsFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar]
TwoArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar] TwoArraysFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar]
TwoVarFunc: TypeAlias = TwoSingleVarFunc | TwoArrayFunc TwoVarsFunc: TypeAlias = TwoSingleVarsFunc | TwoArraysFunc
ThreeSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar] ThreeSingleVarsFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar]
ThreeArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar] ThreeArraysFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar]
ThreeVarFunc: TypeAlias = ThreeSingleVarFunc | ThreeArrayFunc ThreeVarsFunc: TypeAlias = ThreeSingleVarsFunc | ThreeArraysFunc
MultiSingleVarFunc: TypeAlias = Callable[..., SingleVar] MultiSingleVarsFunc: TypeAlias = Callable[..., SingleVar]
MultiArrayFunc: TypeAlias = Callable[..., ArrayVar] MultiArraysFunc: TypeAlias = Callable[..., ArrayVar]
MultiVarFunc: TypeAlias = MultiSingleVarFunc | MultiArrayFunc MultiVarsFunc: TypeAlias = MultiSingleVarsFunc | MultiArraysFunc

View File

@ -20,7 +20,7 @@ class TestPartialDerivative:
from mbcp.mp_math.utils import Approx from mbcp.mp_math.utils import Approx
from mbcp.mp_math.equation import get_partial_derivative_func from mbcp.mp_math.equation import get_partial_derivative_func
partial_derivative_func = get_partial_derivative_func(three_var_func, (0,)) partial_derivative_func = get_partial_derivative_func(three_var_func, 0)
# assert partial_derivative_func(1, 2, 3) == 4.0 # assert partial_derivative_func(1, 2, 3) == 4.0
def df_dx(x, y): def df_dx(x, y):