Deep Learning in Python - Chapter 2

# =============================================================================
# The need for optimization
# =============================================================================

''' Scaling up to multiple data points '''

from sklearn.metrics import mean_squared_error
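# -----------------------------------------------------------------------------
# Note: input_data, target_actuals, predict_with_network(), weights_0 and
# weights_1 are assumed to be pre-loaded by the course environment. Below is a
# minimal sketch of what they might look like (one ReLU hidden layer with two
# nodes); all names and values here are illustrative only and may differ from
# the actual course definitions.
# -----------------------------------------------------------------------------
import numpy as np

def relu(x):
    # ReLU activation: negative values become 0
    return max(0, x)

def predict_with_network(row, weights):
    # Forward pass: two hidden nodes feeding a single output node
    node_0_out = relu((row * weights['node_0']).sum())
    node_1_out = relu((row * weights['node_1']).sum())
    hidden_outputs = np.array([node_0_out, node_1_out])
    return (hidden_outputs * weights['output']).sum()

# Illustrative inputs, targets and two candidate weight sets
input_data = [np.array([0, 3]), np.array([1, 2]), np.array([-1, -2]), np.array([4, 0])]
target_actuals = [1, 3, 5, 7]
weights_0 = {'node_0': np.array([2, 1]),
             'node_1': np.array([1, 2]),
             'output': np.array([1, 1])}
weights_1 = {'node_0': np.array([2, 1]),
             'node_1': np.array([1, 1.5]),
             'output': np.array([1, 1.5])}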

# Create model_output_0 
model_output_0 = []
# Create model_output_1
model_output_1 = []

# Loop over input_data
for row in input_data:
    # Append prediction to model_output_0
    model_output_0.append(predict_with_network(row, weights_0))
    
    # Append prediction to model_output_1
    model_output_1.append(predict_with_network(row, weights_1))

# Calculate the mean squared error for model_output_0: mse_0
mse_0 = mean_squared_error(target_actuals, model_output_0)

# Calculate the mean squared error for model_output_1: mse_1
mse_1 = mean_squared_error(target_actuals, model_output_1)

# Print mse_0 and mse_1
print("Mean squared error with weights_0: %f" %mse_0)
print("Mean squared error with weights_1: %f" %mse_1)

# =============================================================================
# Gradient descent
# =============================================================================

# Given functions: get_slope(input_data, target, weights) and
# get_mse(input_data, target, weights); input_data, target and the starting
# weights are assumed to be pre-loaded.

import matplotlib.pyplot as plt

n_updates = 20
mse_hist = []   # history of the mean squared error after each update
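# -----------------------------------------------------------------------------
# For reference only: a minimal sketch of what get_slope() and get_mse() might
# look like in the simplest case of a network with no hidden layer, where the
# prediction for a single data point is the dot product of input_data and
# weights. The actual pre-defined course functions may differ.
# -----------------------------------------------------------------------------
import numpy as np

def get_error(input_data, target, weights):
    # Prediction error for one data point: prediction minus target
    preds = (input_data * weights).sum()
    return preds - target

def get_slope(input_data, target, weights):
    # Gradient of the squared error with respect to the weights:
    # d/dw (x.w - t)^2 = 2 * x * (x.w - t)
    return 2 * input_data * get_error(input_data, target, weights)

def get_mse(input_data, target, weights):
    # Squared error for one data point
    return get_error(input_data, target, weights) ** 2

# Illustrative starting point (values are assumptions, not from the course)
input_data = np.array([1.0, 2.0, 3.0])
target = 0.0
weights = np.array([0.0, 2.0, 1.0])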

# Iterate over the number of updates
for i in range(n_updates):
    # Calculate the slope: slope
    slope = get_slope(input_data, target, weights)
    
    # Update the weights: weights
    weights = weights - 0.01 * slope
    
    # Calculate mse with new weights: mse
    mse = get_mse(input_data, target, weights)
    
    # Append the mse to mse_hist
    mse_hist.append(mse)

# Plot the mse history
plt.plot(mse_hist)
plt.xlabel('Iterations')
plt.ylabel('Mean Squared Error')
plt.show()

# =============================================================================
# Backpropagation
# =============================================================================

# No code XD
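# -----------------------------------------------------------------------------
# The original notes contain no code for this section. The sketch below is an
# illustrative example (not from the course) of how backpropagation applies
# the chain rule, layer by layer, to get the weight gradients for a network
# with one ReLU hidden layer and a squared-error loss.
# -----------------------------------------------------------------------------
import numpy as np

def relu(z):
    return np.maximum(0, z)

def backprop_single_point(x, t, W1, w2):
    # Forward pass
    z = W1 @ x              # hidden pre-activations
    h = relu(z)             # hidden activations
    p = w2 @ h              # prediction
    error = p - t

    # Backward pass (chain rule)
    d_p = 2 * error         # dLoss/dPrediction for the squared error
    d_w2 = d_p * h          # gradient for the output weights
    d_h = d_p * w2          # error propagated back to the hidden activations
    d_z = d_h * (z > 0)     # through the ReLU (derivative is 0 or 1)
    d_W1 = np.outer(d_z, x) # gradient for the hidden-layer weights
    return d_W1, d_w2

# Illustrative values only
x = np.array([1.0, 2.0])
t = 4.0
W1 = np.array([[1.0, 1.0],
               [-1.0, 1.0]])
w2 = np.array([2.0, 1.0])
print(backprop_single_point(x, t, W1, w2))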
