Streudiagramme in Python verstehen
Streudiagramme in Python mit Matplotlib und Seaborn erstellen und lesen. Themen: Korrelation, Farbkodierung, Regressionslinien und ML-Anwendungen.
Ein Streudiagramm setzt an jedem (x, y)-Paar eines Datensatzes einen Punkt. Die entstehende Punktwolke zeigt, ob zwei numerische Variablen miteinander zusammenhängen, wie stark und in welche Richtung — damit sind Streudiagramme unverzichtbar für die explorative Datenanalyse und Machine-Learning-Workflows.
Dieses Kapitel behandelt:
- Was Streudiagramme zeigen und wie man sie liest
- Erstellen von Streudiagrammen mit Matplotlib und Seaborn
- Anpassen von Farben, Größen und Transparenz
- Kodierung einer dritten Variable durch Farbe oder Größe (Blasendiagramme)
- Darstellung mehrerer Gruppen mit einer Legende
- Hinzufügen einer Regressions-Trendlinie
- Typische Anwendungsfälle im Machine Learning
Was ein Streudiagramm zeigt
Jeder Punkt repräsentiert eine Beobachtung. Die horizontale Achse trägt eine Variable, die vertikale Achse eine andere. Die Gesamtform der Punktwolke verrät die Korrelation zwischen den beiden Variablen.
Das Muster lesen
| Muster | Bedeutung |
|---|---|
| Punkte steigen von links nach rechts | Positive Korrelation — wenn X zunimmt, steigt tendenziell auch Y |
| Punkte fallen von links nach rechts | Negative Korrelation — wenn X zunimmt, sinkt tendenziell Y |
| Keine erkennbare Form | Keine lineare Korrelation zwischen den Variablen |
| Enge, schmale Bahn | Starke Korrelation |
| Breite, diffuse Wolke | Schwache Korrelation |
| Punkte weit von der Hauptwolke entfernt | Ausreißer — lohnt sich zu untersuchen |
Der Pearson-Korrelationskoeffizient r fasst dieses Muster als einzelne Zahl von -1 (perfekt negativ) bis +1 (perfekt positiv) zusammen. Ein Wert nahe 0 bedeutet keine lineare Beziehung. Streudiagramme ermöglichen es, das zu sehen, was r nicht vermitteln kann — zum Beispiel können zwei Datensätze denselben r-Wert teilen und dabei völlig unterschiedliche Formen haben (siehe Anscombe's Quartett).
Streudiagramm mit Matplotlib erstellen
Matplotlibs plt.scatter() ist die flexibelste Option. Installieren Sie Matplotlib, falls noch nicht geschehen:
pip install matplotlib numpyEinfaches Beispiel
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=42)
# Simulate hours studied vs exam score
hours = rng.uniform(1, 10, 40)
score = 5 * hours + rng.normal(scale=8, size=40)
plt.scatter(hours, score)
plt.xlabel('Hours Studied')
plt.ylabel('Exam Score')
plt.title('Hours Studied vs Exam Score')
plt.tight_layout()
plt.show()Der positive Anstieg in der entstehenden Punktwolke zeigt, dass mehr Lernstunden mit höheren Prüfungsergebnissen korrelieren.
Markerfarbe und -größe anpassen
Die drei nützlichsten Parameter für plt.scatter() sind:
c— Farbname, Hex-String oder Array von Werten (über eine Colormap abgebildet)s— Markergröße in Quadratpunkten (Standard 20); Skalar oder Array übergebenalpha— Transparenz von 0 (unsichtbar) bis 1 (vollständig opak); für überlappende Punkte empfehlen sich Werte von 0,4–0,7
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=7)
x = rng.normal(loc=5, scale=2, size=60)
y = rng.normal(loc=5, scale=2, size=60)
plt.scatter(x, y, c='steelblue', s=80, alpha=0.6, edgecolors='white', linewidths=0.5)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Customized Scatter Plot')
plt.tight_layout()
plt.show()edgecolors='white' zusammen mit linewidths=0.5 fügt jedem Punkt einen dünnen weißen Rahmen hinzu, sodass einzelne Punkte bei Überlappung leichter zu unterscheiden sind.
Dritte Variable mit Farbe kodieren
Übergeben Sie ein Array an c, um jeden Punkt anhand einer dritten numerischen Variable einzufärben. Fügen Sie plt.colorbar() hinzu, damit Leser wissen, was die Farben bedeuten:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=3)
x = rng.random(50)
y = rng.random(50)
temperature = rng.uniform(15, 35, 50) # third variable, e.g. temperature in °C
scatter = plt.scatter(x, y, c=temperature, cmap='coolwarm', s=80, alpha=0.8)
plt.colorbar(scatter, label='Temperature (°C)')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Sensor Readings by Temperature')
plt.tight_layout()
plt.show()Verwenden Sie perceptuell gleichmäßige Colormaps — 'viridis', 'plasma', 'cividis' oder 'coolwarm' — anstelle von 'jet' oder 'rainbow', die die Wahrnehmung verzerren und nicht farbenblindfreundlich sind.
Dritte Variable mit Blasengröße kodieren
Übergeben Sie ein Array an s, um die Fläche jedes Markers proportional zu einer dritten Variable zu machen — das nennt man Blasendiagramm:
import matplotlib.pyplot as plt
import numpy as np
countries = ['USA', 'China', 'Japan', 'Germany', 'UK']
gdp = [25.5, 18.0, 4.2, 4.1, 3.1] # trillion USD
life_exp = [76.4, 77.1, 84.3, 80.6, 81.3] # years
population = [334, 1412, 125, 84, 67] # millions — encoded as size
# Scale population to a visible marker area range
sizes = [p * 1.5 for p in population]
plt.scatter(gdp, life_exp, s=sizes, alpha=0.6, edgecolors='black', linewidths=0.8)
for i, name in enumerate(countries):
plt.annotate(name, (gdp[i], life_exp[i]), textcoords='offset points',
xytext=(6, 4), fontsize=9)
plt.xlabel('GDP (trillion USD)')
plt.ylabel('Life Expectancy (years)')
plt.title('GDP vs Life Expectancy (bubble size = population)')
plt.tight_layout()
plt.show()Streudiagramm mit Seaborn erstellen
Seaborns sns.scatterplot() arbeitet direkt mit Pandas DataFrames und bietet Funktionen wie automatische Gruppierung nach einer kategorischen Spalte sowie einen eingebauten hue-Parameter für die Farbkodierung.
Installieren Sie zunächst Seaborn:
pip install seaborn pandasEinfaches Seaborn-Streudiagramm
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
data = pd.DataFrame({
'hours': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'score': [45, 50, 55, 60, 65, 70, 72, 80, 85, 92],
})
sns.scatterplot(data=data, x='hours', y='score')
plt.xlabel('Hours Studied')
plt.ylabel('Exam Score')
plt.title('Hours Studied vs Exam Score')
plt.tight_layout()
plt.show()Gruppen mit hue farblich codieren
Der Parameter hue weist jeder Kategorie automatisch eine andere Farbe zu und fügt eine Legende hinzu:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
data = pd.DataFrame({
'sepal_length': [5.1, 4.9, 6.3, 5.8, 7.0, 6.4, 6.3, 5.8, 7.1, 6.3],
'sepal_width': [3.5, 3.0, 2.9, 2.7, 3.2, 3.2, 3.3, 2.7, 3.0, 2.9],
'species': ['setosa', 'setosa', 'versicolor', 'versicolor',
'virginica', 'virginica', 'virginica', 'versicolor',
'virginica', 'virginica'],
})
sns.scatterplot(data=data, x='sepal_length', y='sepal_width', hue='species')
plt.title('Iris: Sepal Length vs Sepal Width')
plt.tight_layout()
plt.show()Seaborn erstellt die Legende automatisch. Dies entspricht dem mehrfachen Aufruf von plt.scatter() mit verschiedenen Farben.
Regressionslinie mit sns.regplot() hinzufügen
sns.regplot() kombiniert ein Streudiagramm mit einer angepassten Regressionslinie und einem Konfidenzband:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=10)
x = np.linspace(1, 10, 30)
y = 3 * x + rng.normal(scale=4, size=30)
data = pd.DataFrame({'x': x, 'y': y})
sns.regplot(data=data, x='x', y='y', scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot with Regression Line')
plt.tight_layout()
plt.show()Der schattierte Bereich um die Linie ist ein 95-%-Konfidenzintervall. Mit ci=None lässt er sich entfernen.
Mehrere Gruppen darstellen
Mit Matplotlib
Rufen Sie plt.scatter() einmal pro Gruppe auf und setzen Sie bei jedem Aufruf label=:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=0)
groups = {
'Group A': (2, 3),
'Group B': (6, 6),
'Group C': (9, 2),
}
for name, (cx, cy) in groups.items():
x = rng.normal(loc=cx, scale=0.6, size=30)
y = rng.normal(loc=cy, scale=0.6, size=30)
plt.scatter(x, y, s=50, alpha=0.7, label=name)
plt.legend()
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Three Distinct Clusters')
plt.tight_layout()
plt.show()Jeder scatter()-Aufruf wählt automatisch die nächste Farbe aus Matplotlibs Standard-Farbzyklus.
Mit Seaborn
Übergeben Sie einen DataFrame und verwenden Sie hue= sowie optional style=, um Gruppen zu unterscheiden:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=1)
rows = []
for group, (cx, cy) in [('A', (2, 3)), ('B', (6, 6)), ('C', (9, 2))]:
for _ in range(25):
rows.append({'x': rng.normal(cx, 0.6), 'y': rng.normal(cy, 0.6), 'group': group})
df = pd.DataFrame(rows)
sns.scatterplot(data=df, x='x', y='y', hue='group', style='group')
plt.title('Three Clusters — Seaborn Multi-Group')
plt.tight_layout()
plt.show()style='group' weist jeder Gruppe zusätzlich zur Farbe eine eigene Markerform zu, was Lesern hilft, die in Schwarzweiß drucken.
Streudiagramme im Machine Learning
Streudiagramme dienen nicht nur der Erkundung — sie sind Teil des zentralen ML-Workflows.
1. Lineare Zusammenhänge vor der Regression prüfen
Bevor Sie ein lineares Regressionsmodell trainieren, sollten Sie Eingabemerkmale gegen die Zielvariable plotten. Ein annähernd lineares Streumuster deutet darauf hin, dass lineare Regression geeignet ist:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=5)
house_size = rng.uniform(50, 300, 60) # square metres
house_price = 2000 * house_size + rng.normal(scale=40000, size=60) # EUR
plt.scatter(house_size, house_price, alpha=0.6, s=50)
plt.xlabel('House Size (m²)')
plt.ylabel('Price (EUR)')
plt.title('House Size vs Price — linear pattern suggests linear regression')
plt.tight_layout()
plt.show()Wenn das Streudiagramm eine Kurve statt einer Linie zeigt, sind möglicherweise polynomielle Merkmale oder ein anderes Modell erforderlich.
2. Cluster nach K-Means visualisieren
Nach dem Ausführen eines Clustering-Algorithmus wie K-Means können Sie jeden Punkt anhand seines Cluster-Labels einfärben, um die Trennung zu bestätigen:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=8)
# Simulate cluster assignments from k-means
centers = [(1, 1), (5, 5), (9, 1)]
X, labels = [], []
for i, (cx, cy) in enumerate(centers):
X.extend(zip(rng.normal(cx, 0.7, 30), rng.normal(cy, 0.7, 30)))
labels.extend([i] * 30)
X = np.array(X)
labels = np.array(labels)
colors = ['tab:blue', 'tab:orange', 'tab:green']
for k in range(3):
mask = labels == k
plt.scatter(X[mask, 0], X[mask, 1], c=colors[k], s=50, alpha=0.7, label=f'Cluster {k}')
plt.legend()
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('K-Means Cluster Assignments')
plt.tight_layout()
plt.show()Klar getrennte Wolken bestätigen, dass der Algorithmus bedeutungsvolle Gruppierungen gefunden hat.
3. Vorhersagen eines Regressionsmodells auswerten
Stellen Sie nach dem Training eines Modells die tatsächlichen versus vorhergesagten Werte dar. Ein perfektes Modell ergibt Punkte entlang der Diagonalen y = x:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=2)
# Simulate actual and predicted values from a trained model
actual = rng.uniform(10, 100, 50)
predicted = actual + rng.normal(scale=8, size=50) # model with some noise
plt.scatter(actual, predicted, alpha=0.6, s=60, edgecolors='black', linewidths=0.5)
# Draw the ideal y = x line
lim = [min(actual.min(), predicted.min()) - 5, max(actual.max(), predicted.max()) + 5]
plt.plot(lim, lim, 'r--', linewidth=1.5, label='Perfect prediction')
plt.xlim(lim)
plt.ylim(lim)
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.title('Actual vs Predicted — regression model evaluation')
plt.legend()
plt.tight_layout()
plt.show()Punkte, die zufällig um die Diagonale gestreut sind (kein systematischer Bogen oder Fächerform), bedeuten, dass die Fehler des Modells unverzerrt sind.
4. Dimensionsreduktion visualisieren (PCA / t-SNE)
Nach der Reduktion hochdimensionaler Daten auf zwei Dimensionen mit PCA oder t-SNE ist ein Streudiagramm die natürliche Darstellungsform. Jeder Punkt ist eine Beobachtung; die Farbe gibt das Klassen-Label an:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=20)
# Simulate 2-D PCA output for three classes
class_data = {
'Class 0': ((-3, 0), 0.8),
'Class 1': ((0, 3), 0.8),
'Class 2': ((3, 0), 0.8),
}
for label, ((cx, cy), spread) in class_data.items():
x = rng.normal(cx, spread, 40)
y = rng.normal(cy, spread, 40)
plt.scatter(x, y, s=30, alpha=0.7, label=label)
plt.legend()
plt.xlabel('PC 1')
plt.ylabel('PC 2')
plt.title('PCA Projection — 2D visualization of high-dimensional data')
plt.tight_layout()
plt.show()Cluster, die sich nach der Reduktion klar trennen, legen nahe, dass die Klassen durch die ursprünglichen Merkmale tatsächlich unterscheidbar sind.
Streudiagramme in einer Datei speichern
Verwenden Sie plt.savefig() vor plt.show() — der Aufruf von show() zuerst löscht die Grafik:
import matplotlib.pyplot as plt
import numpy as np
rng = np.random.default_rng(seed=99)
x = rng.random(50)
y = rng.random(50)
plt.scatter(x, y, alpha=0.7, s=60)
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Scatter Plot')
plt.tight_layout()
plt.savefig('scatter.png', dpi=150) # raster — good for web
plt.savefig('scatter.pdf') # vector — best for publications
plt.show()Verwenden Sie dpi=300 für druckqualitätssichere PNG-Bilder.
Wann welche Bibliothek verwenden
| Situation | Empfohlenes Werkzeug |
|---|---|
| Schnelles Einzel-Plot mit NumPy-Arrays | matplotlib.pyplot.scatter() |
| Arbeiten mit Pandas DataFrames | seaborn.scatterplot() |
| Farb- oder Größenkodierung pro Punkt benötigt | matplotlib.pyplot.scatter() |
| Automatische Gruppierung nach einer Spalte gewünscht | seaborn.scatterplot(hue=...) |
| Eingebaute Regressionslinie gewünscht | seaborn.regplot() |
| Tiefgehende Matplotlib-Anpassung | fig, ax = plt.subplots() dann ax.scatter() |
Für einen vollständigen Einblick in die Matplotlib-Streudiagramm-Parameter — einschließlich logarithmischer Skalen, Annotationen, Markerformen sowie dem Vergleich von scatter() und plot() — lesen Sie das Kapitel Matplotlib Scatter Plots.
Häufige Fehlerquellen
Variable nicht definiert. Jeder Code-Ausschnitt in diesem Kapitel ist eigenständig. Wenn Sie Ausschnitte kombinieren, stellen Sie sicher, dass x und y im selben Skript definiert sind, bevor Sie plt.scatter() aufrufen.
Grafik zwischen Plots nicht geleert. Nach plt.show() löscht Matplotlib die Grafik. Wenn Sie Ausschnitte in einem Jupyter-Notebook ausführen, erstellt jede Zelle automatisch eine neue Grafik. In einem einfachen Python-Skript rufen Sie plt.figure() auf, um einen neuen Plot zu starten, wenn Sie mehrere separate Diagramme möchten.
Überlagerung von Punkten. Bei vielen übereinanderliegenden Punkten sieht das Diagramm wie ein ausgefüllter Klecks aus. Beheben Sie dies mit alpha=0.3, um die Dichte sichtbar zu machen, oder wechseln Sie zu plt.hexbin() für 2-D-Histogramm-Binning.
Fehlende Farbskala. Wenn Sie ein Array an c übergeben, fügen Sie immer plt.colorbar() hinzu — ohne diese können Leser die Farbskala nicht interpretieren.
Verwandte Kapitel
- Matplotlib Scatter Plots — vollständige Matplotlib-Streudiagramm-Referenz: logarithmische Achsen, Annotationen, Blasendiagramme, Randfarben, Vergleich von
scatter()undplot() - Linear Regression — lineares Modell in Python anpassen und interpretieren
- K-Means Clustering — Daten in Gruppen aufteilen und mit Streudiagrammen visualisieren
- Data Distribution — die Form Ihrer Daten vor der Modellierung verstehen
- Matplotlib Histograms — Verteilung einer einzelnen Variable visualisieren
- Matplotlib Line Plots — Trends über eine kontinuierliche geordnete Variable
- Train / Test Split — Daten vor dem Training und der Auswertung von Modellen aufteilen