-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtraining.py
More file actions
80 lines (59 loc) · 2.22 KB
/
training.py
File metadata and controls
80 lines (59 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import matplotlib.pyplot as plt
def estimated_price(theta0, theta1, mileage):
    """Evaluate the linear hypothesis ``theta0 + theta1 * mileage``.

    Works element-wise when *mileage* is a NumPy array.
    """
    slope_term = theta1 * mileage
    return theta0 + slope_term
def calculate_error(est_price, price):
    """Return the signed residual: prediction minus actual price."""
    residual = est_price - price
    return residual
def calculate_correction_theta0(learning_rate, m, errors):
    """Return the gradient-descent step for the intercept ``theta0``.

    Computes ``learning_rate * (1 / m) * sum(errors)``, i.e. the learning
    rate times the mean residual over *m* samples.
    """
    inverse_count = 1 / m
    total_error = sum(errors)
    return learning_rate * inverse_count * total_error
def calculate_correction_theta1(learning_rate, m, errors, mileages):
    """Return the gradient-descent step for the slope ``theta1``.

    Computes ``learning_rate * (1 / m) * sum(errors[i] * mileages[i])``,
    i.e. the learning rate times the mean of residual * feature over *m*
    samples.

    Args:
        learning_rate: step size.
        m: number of samples used to average the gradient.
        errors: per-sample residuals (list, tuple or NumPy array).
        mileages: per-sample feature values, same length as *errors*.

    Both sequences are converted to float arrays first. A plain
    ``list * list`` would raise TypeError instead of multiplying
    element-wise.
    """
    products = np.asarray(errors, dtype=float) * np.asarray(mileages, dtype=float)
    return learning_rate * (1 / m) * products.sum()
def learning_loop(mileage, price, learning_rate=0.05, max_iterations=None):
    """Fit ``price ~ theta0 + theta1 * mileage`` by batch gradient descent.

    Mileage is divided by 100000 and price by 1000 before training. The
    learned parameters are then mapped back to the original units, so the
    returned thetas apply directly to raw mileage/price values.

    Args:
        mileage: 1-D sequence (list or NumPy array) of mileages.
        price: 1-D sequence of prices, same length as *mileage*.
        learning_rate: gradient-descent step size in the scaled space.
        max_iterations: optional cap on the number of parameter updates.
            ``None`` (the default) iterates until the convergence test
            passes, with no upper bound.

    Returns:
        ``[theta0, theta1]`` in original (unscaled) units.

    Raises:
        ValueError: if the inputs are empty or differ in length.
        RuntimeError: if *max_iterations* updates run without converging.
    """
    mileage_scale = 100000
    price_scale = 1000
    mileages_normalized = np.asarray(mileage, dtype=float) / mileage_scale
    prices_normalized = np.asarray(price, dtype=float) / price_scale
    len_data = len(mileages_normalized)
    if len_data == 0:
        # Without this guard the gradient's 1 / m would divide by zero.
        raise ValueError("cannot train on an empty dataset")
    if len(prices_normalized) != len_data:
        raise ValueError("mileage and price must have the same length")

    theta0 = 0
    theta1 = 0
    iteration = 0
    while True:
        if max_iterations is not None and iteration >= max_iterations:
            raise RuntimeError(
                f"gradient descent did not converge in {max_iterations} iterations"
            )
        iteration += 1
        # The helpers are element-wise, so whole arrays go through them at once.
        est_prices = estimated_price(theta0, theta1, mileages_normalized)
        errors = calculate_error(est_prices, prices_normalized)
        tmp_theta0 = calculate_correction_theta0(learning_rate, len_data, errors)
        tmp_theta1 = calculate_correction_theta1(
            learning_rate, len_data, errors, mileages_normalized
        )
        # Simultaneous update: both steps were computed from the old thetas.
        theta0 = theta0 - tmp_theta0
        theta1 = theta1 - tmp_theta1
        # Stop once both step sizes (in scaled units) are small.
        if abs(tmp_theta0) < 0.0001 and abs(tmp_theta1) < 0.005:
            break

    # price = price_scale * (theta0 + theta1 * mileage / mileage_scale)
    theta0_denormalized = theta0 * price_scale
    theta1_denormalized = theta1 * (price_scale / mileage_scale)
    return [theta0_denormalized, theta1_denormalized]
def save_thetas(theta0, theta1, path="thetas.txt"):
    """Write the trained parameters to *path*, overwriting any previous file.

    The file contains ``theta0`` on the first line and ``theta1`` on the
    second line, with no trailing newline.

    Args:
        theta0: intercept to store.
        theta1: slope to store.
        path: destination file (defaults to ``thetas.txt`` in the CWD).
    """
    # The context manager guarantees the handle is flushed and closed.
    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{theta0}\n{theta1}")
def display_data(mileage, price, theta0, theta1):
    """Plot the raw data points together with the fitted regression line.

    Draws a scatter of (mileage, price), overlays the model's predictions
    in red, labels the axes, then calls ``plt.show()``.
    """
    predicted = [estimated_price(theta0, theta1, km) for km in mileage]
    plt.title("Car price prediction based on mileage")
    plt.scatter(mileage, price)
    plt.plot(mileage, predicted, color="red")
    plt.xlabel("Mileage")
    plt.ylabel("Price")
    plt.show()
def main():
    """Train on ``data.csv``, save the thetas, and plot the fit.

    ``data.csv`` is read as one header line followed by comma-separated
    rows whose first column is mileage and second column is price.

    Any failure is reported on stderr as ``Error: <message>``, and the
    process then exits with status 1 so calling scripts can detect it.
    """
    import sys

    try:
        # ndmin=2 keeps a single-row file two-dimensional, so the column
        # slices below do not raise IndexError.
        data = np.loadtxt("data.csv", skiprows=1, delimiter=",", ndmin=2)
        mileage = data[:, 0]
        price = data[:, 1]
        theta0, theta1 = learning_loop(mileage, price)
        save_thetas(theta0, theta1)
        display_data(mileage, price, theta0, theta1)
    except Exception as e:  # top-level boundary: report any failure and exit
        print(f"Error: {e}", file=sys.stderr)
        raise SystemExit(1) from e
# Only run the training pipeline when executed as a script, not on import.
if "__main__" == __name__:
    main()