From d9c41cd311011c395c6d53a9e92423688400d081 Mon Sep 17 00:00:00 2001 From: snowy Date: Mon, 26 Aug 2024 00:11:02 +0800 Subject: [PATCH] :zap: add partial derivative --- mbcp/mp_math/equation.py | 8 +++++--- mbcp/mp_math/mp_math_typing.py | 18 +++++++++--------- tests/test_partial_derivative.py | 2 +- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mbcp/mp_math/equation.py b/mbcp/mp_math/equation.py index 85ab9aa..35e5be7 100644 --- a/mbcp/mp_math/equation.py +++ b/mbcp/mp_math/equation.py @@ -9,7 +9,7 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved @Software: PyCharm """ -from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarFunc, Number +from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarsFunc, Number from mbcp.mp_math.point import Point3 from mbcp.mp_math.const import EPSILON @@ -44,9 +44,9 @@ class CurveEquation: return "CurveEquation()" -def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarFunc: +def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarsFunc: """ - 求N元函数偏导函数。 + 求N元函数一阶偏导函数。 Args: func: 函数 var: 变量位置,可为整数(一阶偏导)或整数元组(高阶偏导) @@ -64,7 +64,9 @@ def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...], return partial_derivative_func elif isinstance(var, tuple): for i in var: + print("测试第i个变量:", i) func = get_partial_derivative_func(func, i, epsilon) + print("测试第i个变量的偏导:", func(1, 2)) return func else: raise ValueError("Invalid var type") diff --git a/mbcp/mp_math/mp_math_typing.py b/mbcp/mp_math/mp_math_typing.py index 1acc84d..c7285c3 100644 --- a/mbcp/mp_math/mp_math_typing.py +++ b/mbcp/mp_math/mp_math_typing.py @@ -20,14 +20,14 @@ OneSingleVarFunc: TypeAlias = Callable[[SingleVar], SingleVar] OneArrayFunc: TypeAlias = Callable[[ArrayVar], ArrayVar] OneVarFunc: TypeAlias = OneSingleVarFunc | OneArrayFunc -TwoSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar] -TwoArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar] -TwoVarFunc: TypeAlias = TwoSingleVarFunc | TwoArrayFunc +TwoSingleVarsFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar] +TwoArraysFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar] +TwoVarsFunc: TypeAlias = TwoSingleVarsFunc | TwoArraysFunc -ThreeSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar] -ThreeArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar] -ThreeVarFunc: TypeAlias = ThreeSingleVarFunc | ThreeArrayFunc +ThreeSingleVarsFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar] +ThreeArraysFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar] +ThreeVarsFunc: TypeAlias = ThreeSingleVarsFunc | ThreeArraysFunc -MultiSingleVarFunc: TypeAlias = Callable[..., SingleVar] -MultiArrayFunc: TypeAlias = Callable[..., ArrayVar] -MultiVarFunc: TypeAlias = MultiSingleVarFunc | MultiArrayFunc +MultiSingleVarsFunc: TypeAlias = Callable[..., SingleVar] +MultiArraysFunc: TypeAlias = Callable[..., ArrayVar] +MultiVarsFunc: TypeAlias = MultiSingleVarsFunc | MultiArraysFunc diff --git a/tests/test_partial_derivative.py b/tests/test_partial_derivative.py index 6afe85e..37896d0 100644 --- a/tests/test_partial_derivative.py +++ b/tests/test_partial_derivative.py @@ -20,7 +20,7 @@ class TestPartialDerivative: from mbcp.mp_math.utils import Approx from mbcp.mp_math.equation import get_partial_derivative_func - partial_derivative_func = get_partial_derivative_func(three_var_func, (0,)) + partial_derivative_func = get_partial_derivative_func(three_var_func, 0) # assert partial_derivative_func(1, 2, 3) == 4.0 def df_dx(x, y):