mirror of
https://github.com/snowykami/mbcp.git
synced 2024-11-25 15:55:03 +08:00
🐛 fix ε accuracy
This commit is contained in:
parent
d9c41cd311
commit
6162606152
@ -15,5 +15,5 @@ PI = math.pi
|
|||||||
E = math.e
|
E = math.e
|
||||||
GOLDEN_RATIO = (1 + math.sqrt(5)) / 2
|
GOLDEN_RATIO = (1 + math.sqrt(5)) / 2
|
||||||
GAMMA = 0.57721566490153286060651209008240243104215933593992
|
GAMMA = 0.57721566490153286060651209008240243104215933593992
|
||||||
EPSILON = 1e-8
|
EPSILON = 0.0000000000001
|
||||||
|
|
||||||
|
@ -63,10 +63,11 @@ def get_partial_derivative_func(func: MultiVarsFunc, var: int | tuple[int, ...],
|
|||||||
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)
|
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)
|
||||||
return partial_derivative_func
|
return partial_derivative_func
|
||||||
elif isinstance(var, tuple):
|
elif isinstance(var, tuple):
|
||||||
for i in var:
|
def high_order_partial_derivative_func(*args: Var) -> Var:
|
||||||
print("测试第i个变量:", i)
|
result_func = func
|
||||||
func = get_partial_derivative_func(func, i, epsilon)
|
for v in var:
|
||||||
print("测试第i个变量的偏导:", func(1, 2))
|
result_func = get_partial_derivative_func(result_func, v, epsilon)
|
||||||
return func
|
return result_func(*args)
|
||||||
|
return high_order_partial_derivative_func
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid var type")
|
raise ValueError("Invalid var type")
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from mbcp.mp_math.mp_math_typing import RealNumber
|
from mbcp.mp_math.mp_math_typing import RealNumber
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
|
||||||
|
|
||||||
def three_var_func(x: RealNumber, y: RealNumber) -> RealNumber:
|
def three_var_func(x: RealNumber, y: RealNumber) -> RealNumber:
|
||||||
@ -26,6 +27,7 @@ class TestPartialDerivative:
|
|||||||
def df_dx(x, y):
|
def df_dx(x, y):
|
||||||
"""原函数关于x的偏导"""
|
"""原函数关于x的偏导"""
|
||||||
return 3 * (x ** 2) * (y ** 2) - 3 * (y ** 3) - y
|
return 3 * (x ** 2) * (y ** 2) - 3 * (y ** 3) - y
|
||||||
|
|
||||||
logging.info(f"Expected: {df_dx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
logging.info(f"Expected: {df_dx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
assert Approx(partial_derivative_func(1, 2)) == df_dx(1, 2)
|
assert Approx(partial_derivative_func(1, 2)) == df_dx(1, 2)
|
||||||
|
|
||||||
@ -40,6 +42,7 @@ class TestPartialDerivative:
|
|||||||
def df_dy(x, y):
|
def df_dy(x, y):
|
||||||
"""原函数关于y的偏导"""
|
"""原函数关于y的偏导"""
|
||||||
return 2 * (x ** 3) * y - 9 * x * (y ** 2) - x
|
return 2 * (x ** 3) * y - 9 * x * (y ** 2) - x
|
||||||
|
|
||||||
logging.info(f"Expected: {df_dy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
logging.info(f"Expected: {df_dy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
assert Approx(partial_derivative_func(1, 2)) == df_dy(1, 2)
|
assert Approx(partial_derivative_func(1, 2)) == df_dy(1, 2)
|
||||||
|
|
||||||
@ -54,6 +57,7 @@ class TestPartialDerivative:
|
|||||||
def df_dxdy(x, y):
|
def df_dxdy(x, y):
|
||||||
"""原函数关于y和x的偏导"""
|
"""原函数关于y和x的偏导"""
|
||||||
return 6 * x ** 2 * y - 9 * y ** 2 - 1
|
return 6 * x ** 2 * y - 9 * y ** 2 - 1
|
||||||
|
|
||||||
logging.info(f"Expected: {df_dxdy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
logging.info(f"Expected: {df_dxdy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
assert Approx(partial_derivative_func(1, 2)) == df_dxdy(1, 2)
|
assert Approx(partial_derivative_func(1, 2)) == df_dxdy(1, 2)
|
||||||
|
|
||||||
@ -63,11 +67,12 @@ class TestPartialDerivative:
|
|||||||
from mbcp.mp_math.utils import Approx
|
from mbcp.mp_math.utils import Approx
|
||||||
from mbcp.mp_math.equation import get_partial_derivative_func
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
partial_derivative_func = get_partial_derivative_func(three_var_func, (0 , 0))
|
partial_derivative_func = get_partial_derivative_func(three_var_func, (0, 0))
|
||||||
|
|
||||||
def df_dydx(x, y):
|
def df_dydx(x, y):
|
||||||
"""原函数关于x和y的偏导"""
|
"""原函数关于x和y的偏导"""
|
||||||
return 6 * x * y ** 2
|
return 6 * x * y ** 2
|
||||||
|
|
||||||
logging.info(f"Expected: {df_dydx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
logging.info(f"Expected: {df_dydx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
assert Approx(partial_derivative_func(1, 2)) == df_dydx(1, 2)
|
assert Approx(partial_derivative_func(1, 2)) == df_dydx(1, 2)
|
||||||
|
|
||||||
@ -82,5 +87,15 @@ class TestPartialDerivative:
|
|||||||
def d3f_dx3(x, y):
|
def d3f_dx3(x, y):
|
||||||
"""原函数关于x的三阶偏导"""
|
"""原函数关于x的三阶偏导"""
|
||||||
return 6 * (y ** 2)
|
return 6 * (y ** 2)
|
||||||
|
|
||||||
logging.info(f"Expected: {d3f_dx3(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
logging.info(f"Expected: {d3f_dx3(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
assert Approx(partial_derivative_func(1, 2)) == d3f_dx3(1, 2)
|
assert Approx(partial_derivative_func(1, 2)) == d3f_dx3(1, 2)
|
||||||
|
|
||||||
|
def test_possible_error(self):
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
def two_vars_func(x: RealNumber, y: RealNumber) -> RealNumber:
|
||||||
|
return x ** 2 * y ** 2
|
||||||
|
|
||||||
|
partial_func = get_partial_derivative_func(two_vars_func, 0)
|
||||||
|
partial_func_2 = get_partial_derivative_func(two_vars_func, (0, 0))
|
||||||
|
assert Approx(partial_func_2(1, 2)) == 8
|
||||||
|
Loading…
Reference in New Issue
Block a user