
Machine Learning - Overfitting
Overfitting occurs when a model learns the noise in the training data, rather than the underlying patterns. This causes the model to perform well on the training data, but poorly on new data. Essentially, the model becomes too specialized to the training data, and is unable to generalize to new data.
Overfitting is a common problem when using complex models, such as deep neural networks. These models have many parameters, and are able to fit the training data very closely. However, this often comes at the expense of generalization performance.
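The gap between training and test performance can be illustrated with a small sketch. Everything here is an assumption made for illustration: the toy data (a true pattern y = 2x with alternating noise of +1 and -1 on the training labels), the helper names, and the choice of a "memorizer" that returns the label of the nearest training point as a stand-in for an overly flexible model.

```python
# Toy data (an assumption for illustration): the true pattern is y = 2x,
# and each training label carries alternating noise of +1 / -1.
train_x = [float(i) for i in range(10)]
train_y = [2.0 * x + (1.0 if i % 2 == 0 else -1.0) for i, x in enumerate(train_x)]

# New inputs, labelled with the noise-free pattern so that the test error
# measures how far each model is from the underlying relationship.
test_x = [x + 0.25 for x in train_x]
test_y = [2.0 * x for x in test_x]


def memorizer_predict(x):
    """Return the label of the closest training point (memorizes the noise)."""
    nearest = min(range(len(train_x)), key=lambda i: abs(train_x[i] - x))
    return train_y[nearest]


def fit_line(xs, ys):
    """Ordinary least squares for y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def mse(predict, xs, ys):
    """Mean squared error of a prediction function on a dataset."""
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


slope, intercept = fit_line(train_x, train_y)


def line_predict(x):
    """Prediction of the simple straight-line model."""
    return slope * x + intercept


memorizer_train_mse = mse(memorizer_predict, train_x, train_y)
memorizer_test_mse = mse(memorizer_predict, test_x, test_y)
line_train_mse = mse(line_predict, train_x, train_y)
line_test_mse = mse(line_predict, test_x, test_y)

print("memorizer: train", memorizer_train_mse, "test", memorizer_test_mse)
print("line:      train", round(line_train_mse, 3), "test", round(line_test_mse, 3))
```

On this toy data the memorizer reaches a training error of exactly zero, yet its test error is larger than that of the straight line, which cannot fit the noise and therefore has a non-zero training error.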
Causes of Overfitting
There are several factors that can contribute to overfitting −
Complex models − As mentioned earlier, complex models are more likely to overfit than simpler models. This is because they have more parameters, and are able to fit the training data more closely.
Limited training data − When there is not enough training data, it becomes difficult for the model to learn the underlying patterns, and it may instead learn the noise in the data.
Unrepresentative training data − If the training data is not representative of the problem that the model is trying to solve, the model may learn irrelevant patterns that do not generalize well to new data.
Lack of regularization − Regularization is a technique used to prevent overfitting by adding a penalty term to the cost function. If this penalty term is not present, the model is more likely to overfit.
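The effect of the penalty term in the last item can be sketched for the simplest possible case, a single weight w with the cost sum((y - w*x)^2) + penalty * w^2. Setting the derivative of this cost to zero gives w = sum(x*y) / (sum(x*x) + penalty). The function name and the sample data below are assumptions chosen only for illustration.

```python
def ridge_weight(xs, ys, penalty):
    """Weight w that minimizes sum((y - w*x)**2) + penalty * w**2."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + penalty)


# Hypothetical sample data lying close to the line y = 2x.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]

# A larger penalty pulls the fitted weight closer to zero.
for penalty in (0.0, 1.0, 10.0, 100.0):
    print(penalty, round(ridge_weight(xs, ys, penalty), 4))
```

With a penalty of 0 the result is the ordinary least squares weight. As the penalty grows, the weight shrinks, which limits how closely the model can follow individual noisy points.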
Techniques to Prevent Overfitting
There are several techniques that can be used to prevent overfitting in machine learning −
Cross-validation − Cross-validation is a technique used to estimate a model's performance on new, unseen data. It involves dividing the data into several subsets, and using each subset in turn as a validation set, while training on the remaining data. Cross-validation does not change the model by itself, but a large gap between the training and validation scores reveals overfitting, which helps in selecting a model and hyperparameters that generalize well to new data.
Early stopping − Early stopping is a technique used to prevent a model from overfitting by stopping the training process before it has converged completely. This is done by monitoring the validation error during training, and stopping when the error stops improving.
Regularization − Regularization is a technique used to prevent overfitting by adding a penalty term to the cost function. The penalty term encourages the model to have smaller weights, and helps to prevent it from fitting the noise in the training data.
Dropout − Dropout is a technique used in deep neural networks to prevent overfitting. It involves randomly dropping out some of the neurons during training, which forces the remaining neurons to learn more robust features.
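The splitting step of cross-validation described in the first item can be sketched as follows. This is a minimal illustration, not a library implementation: the function name k_fold_indices is hypothetical, and the data is assumed to be already shuffled, which practical workflows usually take care of before splitting.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_indices, validation_indices) for each of k folds."""
    indices = list(range(n_samples))
    # Spread the remainder so that fold sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        validation = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, validation
        start += size


# Each sample is used for validation exactly once across the folds.
for fold, (train_idx, val_idx) in enumerate(k_fold_indices(10, 3)):
    print(fold, train_idx, val_idx)
```

In a full workflow, a model would be trained on each train split and scored on the matching validation split, and the scores would then be averaged.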
Example
Here is an implementation of early stopping and L2 regularization in Python using Keras −
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping
from keras import regularizers

# X_train and y_train are assumed to be already loaded as NumPy arrays;
# y_train must contain binary labels (0 or 1) to match the sigmoid output
# and the binary_crossentropy loss used below

# define the model architecture
model = Sequential()
model.add(Dense(64, input_dim=X_train.shape[1], activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(1, activation='sigmoid'))

# compile the model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# set up early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', patience=5)

# train the model with early stopping and L2 regularization
history = model.fit(X_train, y_train, validation_split=0.2, epochs=100, batch_size=64, callbacks=[early_stopping])
In this code, we have used the Sequential model in Keras to define the model architecture, and we have added L2 regularization to the first two layers using the kernel_regularizer argument. We have also set up an early stopping callback using the EarlyStopping class in Keras, which monitors the validation loss and stops training once it has not improved for 5 consecutive epochs.
During training, we pass in the X_train and y_train data along with a validation split of 0.2, so that the last 20% of the samples are held out and used to compute the validation loss. We also set a batch size of 64 and train for a maximum of 100 epochs.
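The patience rule can be sketched in plain Python. This is only an illustration of the idea, not the Keras implementation: the function name and the list of validation losses are made up for the example.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 0-based.

    Training stops once the validation loss has failed to improve on its
    best value for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best value: remember it and reset the counter.
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # The losses ran out before the patience was exhausted.
    return len(val_losses) - 1, best_epoch


# Hypothetical validation losses, one per epoch.
val_losses = [1.00, 0.80, 0.70, 0.75, 0.72, 0.71, 0.90]
stop_epoch, best_epoch = early_stopping_epoch(val_losses, patience=3)
print(stop_epoch, best_epoch)  # prints "5 2"
```

Here the best loss occurs at epoch 2, and training stops at epoch 5 after three epochs without improvement. Because the final weights are not the best ones in this situation, the Keras EarlyStopping callback offers a restore_best_weights argument that can be set to True.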
Output
When you execute this code, it will produce an output like the one shown below −
Train on 323 samples, validate on 81 samples
Epoch 1/100
323/323 [==============================] - 0s 792us/sample - loss: -8.9033 - accuracy: 0.0000e+00 - val_loss: -15.1467 - val_accuracy: 0.0000e+00
Epoch 2/100
323/323 [==============================] - 0s 46us/sample - loss: -20.4505 - accuracy: 0.0000e+00 - val_loss: -25.7619 - val_accuracy: 0.0000e+00
Epoch 3/100
323/323 [==============================] - 0s 43us/sample - loss: -31.9206 - accuracy: 0.0000e+00 - val_loss: -36.8155 - val_accuracy: 0.0000e+00
Epoch 4/100
323/323 [==============================] - 0s 46us/sample - loss: -44.2281 - accuracy: 0.0000e+00 - val_loss: -49.0378 - val_accuracy: 0.0000e+00
Epoch 5/100
323/323 [==============================] - 0s 52us/sample - loss: -58.3326 - accuracy: 0.0000e+00 - val_loss: -62.9369 - val_accuracy: 0.0000e+00
Epoch 6/100
323/323 [==============================] - 0s 40us/sample - loss: -74.2131 - accuracy: 0.0000e+00 - val_loss: -78.7068 - val_accuracy: 0.0000e+00
-----continue
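Note that the negative loss values and the accuracy of 0 in this log are a sign that the target values used for this run were not binary (0 or 1) labels, since binary cross-entropy cannot be negative for such labels. The log therefore illustrates only the format of the training output. With correctly encoded labels, early stopping and L2 regularization can help prevent overfitting and improve the generalization performance of our model.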