-
-
Notifications
You must be signed in to change notification settings - Fork 46k
/
dual_number_automatic_differentiation.py
141 lines (120 loc) · 4.45 KB
/
dual_number_automatic_differentiation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from math import factorial
"""
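Automatic differentiation using dual numbers.

References:
https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers
https://blog.jliszka.org/2013/10/24/exact-numeric-nth-derivatives.html

Note: this only works for basic functions f(x) built from addition, subtraction,
multiplication, division by a plain number, and non-negative integer powers of x.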
"""
class Dual:
    """
    Number of the form ``real + d1*E1 + d2*E2 + ...``.

    ``real`` is the ordinary part and ``duals[k]`` is the coefficient of the
    (k + 1)-th power of the dual unit ``E``.

    >>> Dual(1, [2, 3])
    1+2E1+3E2
    >>> Dual(1, [1, 2, 3]) + Dual(2, [1])
    3+2E1+2E2+3E3
    >>> Dual(4, 1) ** 2
    16+8E1+1E2+0E3
    """

    def __init__(self, real, rank):
        """
        Build a dual number.

        :param real: the ordinary (non-dual) part.
        :param rank: either an int, in which case that many dual coefficients
            are created and each is set to 1, or an iterable holding the dual
            coefficients themselves.
        """
        self.real = real
        if isinstance(rank, int):
            self.duals = [1] * rank
        else:
            # Copy so an instance never shares (and mutates) a list that is
            # still owned by the caller or by another Dual.
            self.duals = list(rank)

    def __repr__(self):
        return (
            f"{self.real}+"
            f"{'+'.join(str(dual)+'E'+str(n+1)for n,dual in enumerate(self.duals))}"
        )

    def reduce(self):
        """
        Return a new Dual with trailing zero coefficients removed.

        >>> Dual(2, [3, 0, 0]).reduce()
        2+3E1
        >>> Dual(5, [0, 0]).reduce().duals
        []
        """
        trimmed = self.duals.copy()
        # Guard on emptiness: an all-zero list would otherwise be popped dry
        # and the next ``trimmed[-1]`` lookup would raise IndexError.
        while trimmed and trimmed[-1] == 0:
            trimmed.pop()
        return Dual(self.real, trimmed)

    def __add__(self, other):
        """
        Add a plain number (affects only ``real``) or another Dual
        (coefficient-wise sum).
        """
        if not isinstance(other, Dual):
            return Dual(self.real + other, self.duals)
        width = max(len(self.duals), len(other.duals))
        # A missing higher-order coefficient is zero, so pad with 0.
        mine = self.duals + [0] * (width - len(self.duals))
        theirs = other.duals + [0] * (width - len(other.duals))
        return Dual(
            self.real + other.real,
            [left + right for left, right in zip(mine, theirs)],
        )

    __radd__ = __add__

    def __sub__(self, other):
        """Subtract a plain number or another Dual."""
        return self + other * -1

    def __rsub__(self, other):
        """Support ``number - Dual``, mirroring ``__radd__`` and ``__rmul__``."""
        return self * -1 + other

    def __mul__(self, other):
        """
        Multiply by a plain number (scales every part) or by another Dual
        (polynomial product in ``E``).
        """
        if not isinstance(other, Dual):
            return Dual(self.real * other, [coef * other for coef in self.duals])

        # The highest index written below is len(self) + len(other) - 1; the
        # original sizing with one extra, always-zero, trailing slot is kept
        # so the length of the result does not change.
        product = [0] * (len(self.duals) + len(other.duals) + 1)
        # dual x dual: E^(i+1) * E^(j+1) = E^(i+j+2), stored at index i+j+1.
        for i, mine in enumerate(self.duals):
            for j, theirs in enumerate(other.duals):
                product[i + j + 1] += mine * theirs
        # dual x real cross terms keep their own power of E.
        for i, mine in enumerate(self.duals):
            product[i] += mine * other.real
        for j, theirs in enumerate(other.duals):
            product[j] += theirs * self.real
        return Dual(self.real * other.real, product)

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Divide every part by a plain number.

        :raises ValueError: if ``other`` is a Dual (not supported).
        """
        if not isinstance(other, Dual):
            return Dual(self.real / other, [coef / other for coef in self.duals])
        raise ValueError

    def __floordiv__(self, other):
        """
        Floor-divide every part by a plain number.

        :raises ValueError: if ``other`` is a Dual (not supported).
        """
        if not isinstance(other, Dual):
            return Dual(self.real // other, [coef // other for coef in self.duals])
        raise ValueError

    def __pow__(self, n):
        """
        Raise to a non-negative integer power by repeated multiplication.

        ``n == 0`` returns the plain int ``1`` and ``n == 1`` returns ``self``
        itself (not a copy).

        :raises ValueError: if ``n`` is negative or a float.
        """
        if n < 0 or isinstance(n, float):
            raise ValueError("power must be a positive integer")
        if n == 0:
            return 1
        if n == 1:
            return self
        result = self
        for _ in range(n - 1):
            result = result * self
        return result
def differentiate(func, position, order):
    """
    Return the ``order``-th derivative of ``func`` evaluated at ``position``.

    ``func`` is called once with ``Dual(position, 1)``; the coefficient of the
    ``order``-th power of the dual unit in the result, multiplied by
    ``order!``, is the derivative.

    :param func: callable taking one argument and combining it with the
        arithmetic operators that ``Dual`` supports.
    :param position: int or float at which to evaluate the derivative.
    :param order: non-negative int; 0 returns the function value itself.
    :raises ValueError: if ``func`` is not callable, ``position`` is not an
        int or float, or ``order`` is not a non-negative int.

    >>> differentiate(lambda x: x**2, 2, 2)
    2
    >>> differentiate(lambda x: x**2 * x**4, 9, 2)
    196830
    >>> differentiate(lambda y: 0.5 * (y + 3) ** 6, 3.5, 4)
    7605.0
    >>> differentiate(lambda y: y ** 2, 4, 3)
    0
    >>> differentiate(lambda x: x**2, 2, 5)
    0
    >>> differentiate(lambda x: 7, 2, 1)
    0
    >>> differentiate(8, 8, 8)
    Traceback (most recent call last):
        ...
    ValueError: differentiate() requires a function as input for func
    >>> differentiate(lambda x: x **2, "", 1)
    Traceback (most recent call last):
        ...
    ValueError: differentiate() requires a float as input for position
    >>> differentiate(lambda x: x**2, 3, "")
    Traceback (most recent call last):
        ...
    ValueError: differentiate() requires an int as input for order
    >>> differentiate(lambda x: x**2, 3, -1)
    Traceback (most recent call last):
        ...
    ValueError: differentiate() requires a non-negative int as input for order
    """
    if not callable(func):
        raise ValueError("differentiate() requires a function as input for func")
    if not isinstance(position, (float, int)):
        raise ValueError("differentiate() requires a float as input for position")
    if not isinstance(order, int):
        raise ValueError("differentiate() requires an int as input for order")
    if order < 0:
        # Previously this surfaced as an IndexError or as factorial()'s own
        # ValueError, depending on how many coefficients the result had.
        raise ValueError(
            "differentiate() requires a non-negative int as input for order"
        )

    result = func(Dual(position, 1))

    if isinstance(result, (int, float)):
        # func ignored its argument (or used x**0, which yields a plain 1):
        # a constant is its own value and has zero derivatives.
        return result if order == 0 else 0
    if order == 0:
        return result.real
    if order > len(result.duals):
        # No coefficient is stored for that power, so the derivative is 0.
        return 0
    return result.duals[order - 1] * factorial(order)
if __name__ == "__main__":
    # Run the doctests embedded in this module first.
    from doctest import testmod

    testmod()

    # Demo: second derivative of y**6 (written as y**2 * y**4) at y = 9.
    def f(y):
        return y**2 * y**4

    demo_value = differentiate(f, 9, 2)
    print(demo_value)