mirror of
https://github.com/snowykami/mbcp.git
synced 2024-11-22 14:17:38 +08:00
⚡ add partial derivative
This commit is contained in:
parent
90fcee2ff9
commit
cc06c34967
19
mbcp/mp_math/const.py
Normal file
19
mbcp/mp_math/const.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
||||||
|
|
||||||
|
@Time : 2024/8/25 下午9:45
|
||||||
|
@Author : snowykami
|
||||||
|
@Email : snowykami@outlook.com
|
||||||
|
@File : const.py
|
||||||
|
@Software: PyCharm
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
PI = math.pi
|
||||||
|
E = math.e
|
||||||
|
GOLDEN_RATIO = (1 + math.sqrt(5)) / 2
|
||||||
|
GAMMA = 0.57721566490153286060651209008240243104215933593992
|
||||||
|
EPSILON = 1e-8
|
||||||
|
|
@ -8,13 +8,14 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
|||||||
@File : equation.py
|
@File : equation.py
|
||||||
@Software: PyCharm
|
@Software: PyCharm
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .point import Point3
|
from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarFunc, Number
|
||||||
from .mp_math_typing import ONE_VARIABLE_FUNC, TWO_VARIABLES_FUNC, THREE_VARIABLES_FUNC
|
from mbcp.mp_math.point import Point3
|
||||||
|
from mbcp.mp_math.const import EPSILON
|
||||||
|
|
||||||
|
|
||||||
class CurveEquation:
|
class CurveEquation:
|
||||||
def __init__(self, x_func: ONE_VARIABLE_FUNC, y_func: ONE_VARIABLE_FUNC, z_func: ONE_VARIABLE_FUNC):
|
def __init__(self, x_func: OneVarFunc, y_func: OneVarFunc, z_func: OneVarFunc):
|
||||||
"""
|
"""
|
||||||
曲线方程。
|
曲线方程。
|
||||||
:param x_func:
|
:param x_func:
|
||||||
@ -25,14 +26,45 @@ class CurveEquation:
|
|||||||
self.y_func = y_func
|
self.y_func = y_func
|
||||||
self.z_func = z_func
|
self.z_func = z_func
|
||||||
|
|
||||||
def __call__(self, *t: float) -> "Point3" | tuple["Point3"]:
|
def __call__(self, *t: Var) -> Point3 | tuple[Point3, ...]:
|
||||||
|
"""
|
||||||
|
计算曲线上的点。
|
||||||
|
Args:
|
||||||
|
*t:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
if len(t) == 1:
|
if len(t) == 1:
|
||||||
return Point3(self.x_func(t[0]), self.y_func(t[0]), self.z_func(t[0]))
|
return Point3(self.x_func(t[0]), self.y_func(t[0]), self.z_func(t[0]))
|
||||||
else:
|
else:
|
||||||
# np加速
|
return tuple([Point3(x, y, z) for x, y, z in zip(self.x_func(t), self.y_func(t), self.z_func(t))])
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "CurveEquation()"
|
return "CurveEquation()"
|
||||||
|
|
||||||
|
|
||||||
|
def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarFunc:
|
||||||
|
"""
|
||||||
|
求N元函数偏导函数。
|
||||||
|
Args:
|
||||||
|
func: 函数
|
||||||
|
var: 变量位置,可为整数(一阶偏导)或整数元组(高阶偏导)
|
||||||
|
epsilon: 偏移量
|
||||||
|
Returns:
|
||||||
|
偏导函数
|
||||||
|
"""
|
||||||
|
if isinstance(var, int):
|
||||||
|
def partial_derivative_func(*args: Var) -> Var:
|
||||||
|
args_list_plus = list(args)
|
||||||
|
args_list_plus[var] += epsilon
|
||||||
|
args_list_minus = list(args)
|
||||||
|
args_list_minus[var] -= epsilon
|
||||||
|
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)
|
||||||
|
return partial_derivative_func
|
||||||
|
elif isinstance(var, tuple):
|
||||||
|
for i in var:
|
||||||
|
func = get_partial_derivative_func(func, i, epsilon)
|
||||||
|
return func
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid var type")
|
||||||
|
@ -8,11 +8,26 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
|||||||
@File : mp_math_typing.py
|
@File : mp_math_typing.py
|
||||||
@Software: PyCharm
|
@Software: PyCharm
|
||||||
"""
|
"""
|
||||||
from typing import Callable, Iterable, TypeAlias
|
from typing import Callable, Iterable, TypeAlias, TypeVar
|
||||||
|
|
||||||
"""自变量"""
|
RealNumber: TypeAlias = int | float
|
||||||
VAR: TypeAlias = float | Iterable[float] # 为后期支持多维矢量化做准备
|
Number: TypeAlias = RealNumber | complex
|
||||||
|
SingleVar = TypeVar("SingleVar", bound=Number)
|
||||||
|
ArrayVar = TypeVar("ArrayVar", bound=Iterable[Number])
|
||||||
|
Var: TypeAlias = SingleVar | ArrayVar
|
||||||
|
|
||||||
ONE_VARIABLE_FUNC: TypeAlias = Callable[[VAR], float]
|
OneSingleVarFunc: TypeAlias = Callable[[SingleVar], SingleVar]
|
||||||
TWO_VARIABLES_FUNC: TypeAlias = Callable[[VAR, VAR], float]
|
OneArrayFunc: TypeAlias = Callable[[ArrayVar], ArrayVar]
|
||||||
THREE_VARIABLES_FUNC: TypeAlias = Callable[[VAR, VAR, VAR], float]
|
OneVarFunc: TypeAlias = OneSingleVarFunc | OneArrayFunc
|
||||||
|
|
||||||
|
TwoSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar]
|
||||||
|
TwoArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar]
|
||||||
|
TwoVarFunc: TypeAlias = TwoSingleVarFunc | TwoArrayFunc
|
||||||
|
|
||||||
|
ThreeSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar]
|
||||||
|
ThreeArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar]
|
||||||
|
ThreeVarFunc: TypeAlias = ThreeSingleVarFunc | ThreeArrayFunc
|
||||||
|
|
||||||
|
MultiSingleVarFunc: TypeAlias = Callable[..., SingleVar]
|
||||||
|
MultiArrayFunc: TypeAlias = Callable[..., ArrayVar]
|
||||||
|
MultiVarFunc: TypeAlias = MultiSingleVarFunc | MultiArrayFunc
|
||||||
|
@ -8,6 +8,9 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
|
|||||||
@File : utils.py
|
@File : utils.py
|
||||||
@Software: PyCharm
|
@Software: PyCharm
|
||||||
"""
|
"""
|
||||||
|
from typing import overload
|
||||||
|
|
||||||
|
from mbcp.mp_math.mp_math_typing import RealNumber
|
||||||
|
|
||||||
|
|
||||||
def clamp(x: float, min_: float, max_: float) -> float:
|
def clamp(x: float, min_: float, max_: float) -> float:
|
||||||
@ -22,3 +25,34 @@ def clamp(x: float, min_: float, max_: float) -> float:
|
|||||||
限制后的值
|
限制后的值
|
||||||
"""
|
"""
|
||||||
return max(min(x, max_), min_)
|
return max(min(x, max_), min_)
|
||||||
|
|
||||||
|
|
||||||
|
class Approx(float):
|
||||||
|
"""
|
||||||
|
用于近似比较浮点数的类。
|
||||||
|
"""
|
||||||
|
epsilon = 0.001
|
||||||
|
"""全局近似值。"""
|
||||||
|
|
||||||
|
def __new__(cls, x: RealNumber):
|
||||||
|
return super().__new__(cls, x)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return abs(self - other) < Approx.epsilon
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
|
||||||
|
def approx(x: float, y: float = 0.0, epsilon: float = 0.0001) -> bool:
|
||||||
|
"""
|
||||||
|
判断两个数是否近似相等。或包装一个实数,用于判断是否近似于0。
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
y:
|
||||||
|
epsilon:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否近似相等
|
||||||
|
"""
|
||||||
|
return abs(x - y) < epsilon
|
||||||
|
3
tests/pytest.ini
Normal file
3
tests/pytest.ini
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[pytest]
|
||||||
|
log_cli = true
|
||||||
|
log_level = INFO
|
86
tests/test_partial_derivative.py
Normal file
86
tests/test_partial_derivative.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
偏导测试
|
||||||
|
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from mbcp.mp_math.mp_math_typing import RealNumber
|
||||||
|
|
||||||
|
|
||||||
|
def three_var_func(x: RealNumber, y: RealNumber) -> RealNumber:
|
||||||
|
return x ** 3 * y ** 2 - 3 * x * y ** 3 - x * y + 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestPartialDerivative:
|
||||||
|
# 样例来源:同济大学《高等数学》第八版下册 第九章第二节 例6
|
||||||
|
def test_2v_1o_1v(self):
|
||||||
|
"""测试二元函数关于第一个变量(x)的一阶偏导 df/dx"""
|
||||||
|
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
|
partial_derivative_func = get_partial_derivative_func(three_var_func, (0,))
|
||||||
|
|
||||||
|
# assert partial_derivative_func(1, 2, 3) == 4.0
|
||||||
|
def df_dx(x, y):
|
||||||
|
"""原函数关于x的偏导"""
|
||||||
|
return 3 * (x ** 2) * (y ** 2) - 3 * (y ** 3) - y
|
||||||
|
logging.info(f"Expected: {df_dx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
|
assert Approx(partial_derivative_func(1, 2)) == df_dx(1, 2)
|
||||||
|
|
||||||
|
def test_2v_1o_2v(self):
|
||||||
|
"""测试二元函数关于第二个变量(y)的一阶偏导 df/dy"""
|
||||||
|
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
|
partial_derivative_func = get_partial_derivative_func(three_var_func, 1)
|
||||||
|
|
||||||
|
def df_dy(x, y):
|
||||||
|
"""原函数关于y的偏导"""
|
||||||
|
return 2 * (x ** 3) * y - 9 * x * (y ** 2) - x
|
||||||
|
logging.info(f"Expected: {df_dy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
|
assert Approx(partial_derivative_func(1, 2)) == df_dy(1, 2)
|
||||||
|
|
||||||
|
def test_2v_2o_12v(self):
|
||||||
|
"""高阶偏导d^2f/(dxdy)"""
|
||||||
|
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
|
partial_derivative_func = get_partial_derivative_func(three_var_func, (0, 1))
|
||||||
|
|
||||||
|
def df_dxdy(x, y):
|
||||||
|
"""原函数关于y和x的偏导"""
|
||||||
|
return 6 * x ** 2 * y - 9 * y ** 2 - 1
|
||||||
|
logging.info(f"Expected: {df_dxdy(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
|
assert Approx(partial_derivative_func(1, 2)) == df_dxdy(1, 2)
|
||||||
|
|
||||||
|
def test_2v_2o_1v2(self):
|
||||||
|
"""二阶偏导d^2f/(dx^2)"""
|
||||||
|
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
|
partial_derivative_func = get_partial_derivative_func(three_var_func, (0 , 0))
|
||||||
|
|
||||||
|
def df_dydx(x, y):
|
||||||
|
"""原函数关于x和y的偏导"""
|
||||||
|
return 6 * x * y ** 2
|
||||||
|
logging.info(f"Expected: {df_dydx(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
|
assert Approx(partial_derivative_func(1, 2)) == df_dydx(1, 2)
|
||||||
|
|
||||||
|
def test_2v_3o_1v3(self):
|
||||||
|
"""高阶偏导d^3f/(dx^3)"""
|
||||||
|
|
||||||
|
from mbcp.mp_math.utils import Approx
|
||||||
|
from mbcp.mp_math.equation import get_partial_derivative_func
|
||||||
|
|
||||||
|
partial_derivative_func = get_partial_derivative_func(three_var_func, (0, 0, 0))
|
||||||
|
|
||||||
|
def d3f_dx3(x, y):
|
||||||
|
"""原函数关于x的三阶偏导"""
|
||||||
|
return 6 * (y ** 2)
|
||||||
|
logging.info(f"Expected: {d3f_dx3(1, 2)}, Actual: {partial_derivative_func(1, 2)}")
|
||||||
|
assert Approx(partial_derivative_func(1, 2)) == d3f_dx3(1, 2)
|
Loading…
Reference in New Issue
Block a user