Entscheidungsbaum
Wie Entscheidungsbäume funktionieren, Klassifikations- und Regressionsbäume in Python mit scikit-learn erstellen, Hyperparameter abstimmen und visualisieren.
Ein Entscheidungsbaum ist ein überwachter Machine-Learning-Algorithmus, der Vorhersagen trifft, indem er eine Hierarchie von Wenn-dann-sonst-Regeln aus Trainingsdaten lernt. Jeder interne Knoten prüft ein Merkmal, jeder Zweig stellt ein Ergebnis dieses Tests dar, und jedes Blatt enthält eine Vorhersage (ein Klassenlabel bei der Klassifikation oder einen numerischen Wert bei der Regression).
Dieses Kapitel behandelt:
- Wie Entscheidungsbäume Daten mithilfe von Unreinheitsmaßen (Gini und Entropie) aufteilen
- Den Aufbau eines Klassifikationsbaums und eines Regressionsbaums in Python mit
scikit-learn - Die Steuerung der Baumtiefe und die Vermeidung von Overfitting durch Hyperparameter
- Die Visualisierung und Inspektion eines trainierten Baums
- Vorteile, Einschränkungen und Einsatzszenarien von Entscheidungsbäumen
Wie ein Entscheidungsbaum Daten aufteilt
Beim Training durchsucht der Algorithmus jedes Merkmal und jeden möglichen Schwellenwert, um die Aufteilung zu finden, die die Unreinheit am stärksten reduziert — ein Maß dafür, wie gemischt die Klassen in einem Knoten sind.
In scikit-learn sind zwei Unreinheitsmaße verbreitet:
Gini-Unreinheit
Die Gini-Unreinheit misst die Wahrscheinlichkeit, eine zufällig gewählte Stichprobe falsch zu klassifizieren, wenn sie gemäß der Klassenverteilung im Knoten beschriftet würde.
Gini(node) = 1 - Σ pᵢ²Ein reiner Knoten (alle Stichproben gehören einer Klasse an) hat Gini = 0. Ein maximal gemischter Knoten hat einen Gini-Wert nahe 0,5 bei binärer Klassifikation.
Entropie und Informationsgewinn
Die Entropie stammt aus der Informationstheorie. Sie ist maximal, wenn die Klassen gleichmäßig verteilt sind, und null, wenn der Knoten rein ist.
Entropy(node) = -Σ pᵢ log₂(pᵢ)Der Informationsgewinn ist der Rückgang der Entropie nach einer Aufteilung. Der Algorithmus wählt die Aufteilung, die den größten Informationsgewinn liefert. In scikit-learn wählen Sie zwischen beiden über den Parameter criterion ("gini" ist der Standard).
Rekursive Aufteilung
Die Aufteilung wird rekursiv für jeden Kindknoten wiederholt, bis eine Abbruchbedingung erfüllt ist: Der Knoten ist rein, kein Merkmal verbessert die Unreinheit, oder ein Tiefen-/Größenlimit ist erreicht. Dies erzeugt die binäre Baumstruktur.
Klassifikationsbaum in Python
Der Iris-Datensatz hat 150 Stichproben und 4 numerische Merkmale. Ziel ist es, eine von drei Blumenarten vorherzusagen.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
# Load dataset
data = load_iris()
X, y = data.data, data.target
# Split: 80 % train, 20 % test
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Train — limit depth to 3 to keep the tree readable
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=42)
clf.fit(X_train, y_train)
# Evaluate
y_pred = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.2f}")
print(classification_report(y_test, y_pred, target_names=data.target_names))Erwartete Ausgabe:
Accuracy: 1.00
precision recall f1-score support
setosa 1.00 1.00 1.00 10
versicolor 1.00 1.00 1.00 9
virginica 1.00 1.00 1.00 11
accuracy 1.00 30
macro avg 1.00 1.00 1.00 30
weighted avg 1.00 1.00 1.00 30Der Iris-Datensatz ist mit Tiefe 3 linear trennbar, sodass der Baum eine perfekte Testgenauigkeit erreicht. Reale Datensätze werden unordentlicher sein.
Neue Stichproben vorhersagen
Nach dem Training rufen Sie predict() auf, um neue Beobachtungen zu klassifizieren, und predict_proba(), um Klassenwahrscheinlichkeiten zu erhalten:
import numpy as np
# A new flower: sepal length 5.1, sepal width 3.5, petal length 1.4, petal width 0.2
new_sample = np.array([[5.1, 3.5, 1.4, 0.2]])
predicted_class = clf.predict(new_sample)
predicted_proba = clf.predict_proba(new_sample)
print("Predicted class:", data.target_names[predicted_class[0]])
print("Class probabilities:", predicted_proba)Erwartete Ausgabe:
Predicted class: setosa
Class probabilities: [[1. 0. 0.]]Regressionsbaum in Python
Entscheidungsbäume können auch kontinuierliche Zielwerte verarbeiten. Verwenden Sie statt DecisionTreeClassifier den DecisionTreeRegressor.
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import numpy as np
# Synthetic regression dataset
X_reg, y_reg = make_regression(
n_samples=300, n_features=5, noise=20, random_state=42
)
X_train_r, X_test_r, y_train_r, y_test_r = train_test_split(
X_reg, y_reg, test_size=0.2, random_state=42
)
reg = DecisionTreeRegressor(max_depth=5, random_state=42)
reg.fit(X_train_r, y_train_r)
y_pred_r = reg.predict(X_test_r)
mse = mean_squared_error(y_test_r, y_pred_r)
r2 = r2_score(y_test_r, y_pred_r)
print(f"MSE : {mse:.2f}")
print(f"R² : {r2:.2f}")Ein Regressionsbaum teilt durch Minimierung des mittleren quadratischen Fehlers (MSE) innerhalb jedes Knotens auf und sagt den mittleren Zielwert aller Trainingsstichproben vorher, die ein Blatt erreichen.
Hyperparameter abstimmen
Ohne Einschränkungen wächst ein Entscheidungsbaum, bis jedes Blatt rein ist, und memoriert so den Trainingssatz perfekt (Overfitting). Hyperparameter steuern die Baumkomplexität:
| Parameter | Standard | Wirkung |
|---|---|---|
max_depth | None | Maximale Anzahl von Ebenen. Niedriger = einfacherer Baum. |
min_samples_split | 2 | Mindestanzahl an Stichproben, um einen Knoten aufzuteilen. Höher = weniger Aufteilungen. |
min_samples_leaf | 1 | Mindestanzahl an Stichproben in einem Blatt. Höher = glattere Grenzen. |
max_features | None | Anzahl der bei jeder Aufteilung zu berücksichtigenden Merkmale (nützlich für die Merkmalsauswahl). |
criterion | "gini" | Unreinheitsmaß: "gini" oder "entropy" für Klassifikatoren; "squared_error" für Regressoren. |
Verwenden Sie Kreuzvalidierung und Grid Search, um die beste Kombination zu finden:
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
data = load_iris()
X, y = data.data, data.target
param_grid = {
"max_depth": [2, 3, 4, 5, None],
"min_samples_split": [2, 5, 10],
"criterion": ["gini", "entropy"],
}
grid_search = GridSearchCV(
DecisionTreeClassifier(random_state=42),
param_grid,
cv=5,
scoring="accuracy",
)
grid_search.fit(X, y)
print("Best params :", grid_search.best_params_)
print(f"Best CV score: {grid_search.best_score_:.3f}")Erwartete Ausgabe (Werte können je nach scikit-learn-Version leicht variieren):
Best params : {'criterion': 'gini', 'max_depth': 3, 'min_samples_split': 2}
Best CV score: 0.973Kategoriale Merkmale verarbeiten
scikit-learn-Entscheidungsbäume erfordern numerische Eingaben. Kodieren Sie kategoriale Spalten vor dem Training:
- Ordinale Kategorien (z. B. Größe: klein < mittel < groß): verwenden Sie
OrdinalEncoder. - Nominale Kategorien (z. B. Farbe: rot, grün, blau): verwenden Sie
OneHotEncoder, um keine Reihenfolge zu implizieren.
from sklearn.preprocessing import OrdinalEncoder
import numpy as np
# Encode only the categorical column; keep the numeric column as-is
sizes = np.array([["small"], ["large"], ["medium"], ["large"]])
weights = np.array([1.2, 3.4, 2.1, 4.0])
# Explicit category order: large=0, medium=1, small=2
enc = OrdinalEncoder(categories=[["large", "medium", "small"]])
sizes_encoded = enc.fit_transform(sizes)
X_encoded = np.column_stack([sizes_encoded, weights])
print(X_encoded)Erwartete Ausgabe:
[[2. 1.2]
[0. 3.4]
[1. 2.1]
[0. 4. ]]Weitere Details finden Sie im Kapitel Kategoriale Daten.
Einen Entscheidungsbaum visualisieren
Die Inspektion der Baumstruktur zeigt, welche Merkmale die meisten Aufteilungen treiben, und macht das Modell nachvollziehbar.
Textdarstellung
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.datasets import load_iris
data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(data.data, data.target)
print(export_text(clf, feature_names=list(data.feature_names)))Erwartete Ausgabe:
|--- petal length (cm) <= 2.45
| |--- class: 0
|--- petal length (cm) > 2.45
| |--- petal width (cm) <= 1.75
| | |--- class: 1
| |--- petal width (cm) > 1.75
| | |--- class: 2Grafische Darstellung
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
data = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42)
clf.fit(data.data, data.target)
plt.figure(figsize=(10, 5))
plot_tree(
clf,
feature_names=data.feature_names,
class_names=data.target_names,
filled=True,
rounded=True,
)
plt.title("Iris Decision Tree (max_depth=2)")
plt.tight_layout()
plt.savefig("iris_tree.png", dpi=150)
plt.show()filled=True färbt jeden Knoten nach seiner Mehrheitsklasse; dunklere Farbtöne bedeuten höhere Klassenreinheit.
Merkmalswichtigkeit
Nach dem Training gibt feature_importances_ jedem Merkmal eine Punktzahl zwischen 0 und 1, wobei ein höherer Wert bedeutet, dass das Merkmal stärker zur Reduktion der Unreinheit über alle Aufteilungen beigetragen hat:
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
import numpy as np
data = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(data.data, data.target)
importances = clf.feature_importances_
for name, imp in sorted(
zip(data.feature_names, importances), key=lambda x: x[1], reverse=True
):
print(f"{name:30s}: {imp:.4f}")Erwartete Ausgabe:
petal length (cm) : 0.5856
petal width (cm) : 0.4144
sepal length (cm) : 0.0000
sepal width (cm) : 0.0000Merkmale mit einer Wichtigkeit von 0 wurden von keiner Aufteilung verwendet und könnten entfernt werden, um das Modell zu vereinfachen.
Vorteile und Einschränkungen
Wann Entscheidungsbäume einsetzen
- Sie benötigen ein interpretierbares Modell — die Regeln können als Klartext ausgegeben werden.
- Ihr Datensatz enthält eine Mischung aus numerischen und kategorialen Merkmalen (nach der Kodierung).
- Sie möchten schnell eine Basislinie erstellen, bevor Sie Ensemble-Methoden ausprobieren.
- Die Beziehung zwischen Merkmalen und Zielwert ist nicht-linear oder beinhaltet Wechselwirkungen.
Einschränkungen
| Einschränkung | Abhilfe |
|---|---|
| Overfittet leicht ohne Abstimmung | max_depth, min_samples_leaf einschränken; Kreuzvalidierung verwenden |
| Hohe Varianz (kleine Datenänderungen → anderer Baum) | Ensemble-Methoden verwenden: Random Forest / Bootstrap Aggregation |
| Bevorzugt Merkmale mit mehr eindeutigen Werten | max_features verwenden oder Split-Kriterien normalisieren |
| Schlecht beim Extrapolieren über den Trainingsbereich hinaus | Lineare Modelle für Extrapolationsaufgaben bevorzugen |
| Nur achsenparallele Aufteilungen | Oblique Trees existieren, sind aber nicht in scikit-learn enthalten |
Entscheidungsbäume vs. verwandte Algorithmen
| Algorithmus | Wesentlicher Unterschied |
|---|---|
| Logistische Regression | Lineare Grenze; besser für linear trennbare Daten; verarbeitet Wechselwirkungen nicht automatisch |
| K-Nearest Neighbors | Instanzbasiert; kein explizites Modell; erfordert Merkmalsskalierung |
| Entscheidungsbaum | Nicht-linear; keine Skalierung erforderlich; gut interpretierbar |
| Random Forest (siehe Bootstrap Aggregation) | Ensemble vieler Bäume; viel geringere Varianz; weniger interpretierbar |
Wichtigste Erkenntnisse
- Entscheidungsbäume teilen Daten auf, indem sie den Informationsgewinn maximieren (oder die Gini-Unreinheit minimieren) an jedem Knoten; der Prozess wiederholt sich rekursiv.
DecisionTreeClassifierundDecisionTreeRegressorin scikit-learn teilen dieselbe API und dieselben Hyperparameter-Namen.- Legen Sie immer
max_depthodermin_samples_leaffest, um Overfitting zu verhindern; stimmen Sie diese mit Grid Search und Kreuzvalidierung ab. feature_importances_zeigt, auf welche Merkmale der Baum am meisten angewiesen ist — nützlich für die Merkmalsauswahl.- Einzelne Bäume sind eine gute interpretierbare Basislinie, aber Ensemble-Methoden wie Random Forest übertreffen sie bei realen Daten fast immer.