From 6162606152a74dd304815210a1db2e1bb0cd6c48 Mon Sep 17 00:00:00 2001 From: snowy Date: Mon, 26 Aug 2024 01:31:05 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20fix=20=CE=B5=20accuracy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mbcp/mp_math/const.py | 2 +- mbcp/mp_math/equation.py | 11 ++++++----- tests/test_partial_derivative.py | 17 ++++++++++++++++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/mbcp/mp_math/const.py b/mbcp/mp_math/const.py index 69f0ae7..70f6178 100644 --- a/mbcp/mp_math/const.py +++ b/mbcp/mp_math/const.py @@ -15,5 +15,5 @@ PI = math.pi E = math.e GOLDEN_RATIO = (1 + math.sqrt(5)) / 2 GAMMA = 0.57721566490153286060651209008240243104215933593992 -EPSILON = 1e-8 +EPSILON = 0.0000000000001 diff --git a/mbcp/mp_math/equation.py b/mbcp/mp_math/equation.py index 35e5be7..efecc63 100644 --- a/mbcp/mp_math/equation.py +++ b/mbcp/mp_math/equation.py @@ -63,10 +63,11 @@ def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...], return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon) return partial_derivative_func elif isinstance(var, tuple): - for i in var: - print("测试第i个变量:", i) - func = get_partial_derivative_func(func, i, epsilon) - print("测试第i个变量的偏导:", func(1, 2)) - return func + def high_order_partial_derivative_func(*args: Var) -> Var: + result_func = func + for v in var: + result_func = get_partial_derivative_func(result_func, v, epsilon) + return result_func(*args) + return high_order_partial_derivative_func else: raise ValueError("Invalid var type") diff --git a/tests/test_partial_derivative.py b/tests/test_partial_derivative.py index 37896d0..588daa1 100644 --- a/tests/test_partial_derivative.py +++ b/tests/test_partial_derivative.py @@ -6,6 +6,7 @@ import logging from mbcp.mp_math.mp_math_typing import RealNumber +from mbcp.mp_math.utils import Approx def three_var_func(x: RealNumber, y: RealNumber) -> RealNumber: @@ -26,6 +27,7 @@ class TestPartialDerivative: def df_dx(x, y): """原函数关于x的偏导""" return 3 * (x ** 2) * (y ** 2) - 3 * (y ** 3) - y + logging.info(f"Expected: {df_dx(1, 2)}, Actual: {partial_derivative_func(1, 2)}") assert Approx(partial_derivative_func(1, 2)) == df_dx(1, 2) @@ -40,6 +42,7 @@ class TestPartialDerivative: def df_dy(x, y): """原函数关于y的偏导""" return 2 * (x ** 3) * y - 9 * x * (y ** 2) - x + logging.info(f"Expected: {df_dy(1, 2)}, Actual: {partial_derivative_func(1, 2)}") assert Approx(partial_derivative_func(1, 2)) == df_dy(1, 2) @@ -54,6 +57,7 @@ class TestPartialDerivative: def df_dxdy(x, y): """原函数关于y和x的偏导""" return 6 * x ** 2 * y - 9 * y ** 2 - 1 + logging.info(f"Expected: {df_dxdy(1, 2)}, Actual: {partial_derivative_func(1, 2)}") assert Approx(partial_derivative_func(1, 2)) == df_dxdy(1, 2) @@ -63,11 +67,12 @@ class TestPartialDerivative: from mbcp.mp_math.utils import Approx from mbcp.mp_math.equation import get_partial_derivative_func - partial_derivative_func = get_partial_derivative_func(three_var_func, (0 , 0)) + partial_derivative_func = get_partial_derivative_func(three_var_func, (0, 0)) def df_dydx(x, y): """原函数关于x和y的偏导""" return 6 * x * y ** 2 + logging.info(f"Expected: {df_dydx(1, 2)}, Actual: {partial_derivative_func(1, 2)}") assert Approx(partial_derivative_func(1, 2)) == df_dydx(1, 2) @@ -82,5 +87,15 @@ class TestPartialDerivative: def d3f_dx3(x, y): """原函数关于x的三阶偏导""" return 6 * (y ** 2) + logging.info(f"Expected: {d3f_dx3(1, 2)}, Actual: {partial_derivative_func(1, 2)}") assert Approx(partial_derivative_func(1, 2)) == d3f_dx3(1, 2) + + def test_possible_error(self): + from mbcp.mp_math.equation import get_partial_derivative_func + def two_vars_func(x: RealNumber, y: RealNumber) -> RealNumber: + return x ** 2 * y ** 2 + + partial_func = get_partial_derivative_func(two_vars_func, 0) + partial_func_2 = get_partial_derivative_func(two_vars_func, (0, 0)) + assert Approx(partial_func_2(1, 2)) == 8