python - 从混合高斯分布生成二维样本数据集

标签 python numpy matplotlib dataset

我想生成一个二维样本数据集。我复制了此link中所述的代码并将其加倍以生成向量 X,Y,将它们分散为二维数据集,如下所示。但结果并不理想。事实上我想要如下图所示的东西。

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

mu = [1,4]
sigma = [2, 1]
p_i = [0.3, 0.7]
n = 1000

x = []
y=[]
for i in range(n):
    z_i = np.argmax(np.random.multinomial(1, p_i)) #np.random.multinomial(1,[0.3,0.5,0.2]) returns the result of an experiment
    #of rolling a dice. the result is as this: [1,0,0]. this means that the side one occurs in the experiment and the others 
    #not. the goal is choosing mu[i] in a random way
    x_i = np.random.normal(mu[z_i], sigma[z_i])
    x.append(x_i)

    
mu = [3,6]
sigma = [1, 2]
p_i = [0.6, 0.4]    

for i in range(n):
    z_i = np.argmax(np.random.multinomial(1, p_i)) #np.random.multinomial(1,[0.3,0.5,0.2]) returns the result of an experiment
    #of rolling a dice. the result is as this: [1,0,0]. this means that the side one occurs in the experiment and the others 
    #not. the goal is choosing mu[i] in a random way
    y_i = np.random.normal(mu[z_i], sigma[z_i])
    y.append(y_i)

plt.scatter(x, y)
plt.show()

` enter image description here

谁能帮帮我吗?

最佳答案

看起来您想要绘制的是从 2 个不同的 2D 高斯采样的数据。下面的代码可以绘制如下所示的模拟数据。请随意调整均值和协方差矩阵以满足您的需求。

from numpy.random import multivariate_normal

# First 2D gaussian:
mu = [1, 3]
cov = [[0.07, 0],[0, 1.8]]
x, y = np.random.multivariate_normal(mu, cov, 200).T

plt.figure(figsize=(10,6))
plt.scatter(x, y, s=5, color='blue')
ax = plt.gca()

# Second 2D gaussian:
mu = [2, 1]
cov = [[0.8, -0.4],[-0.4, 0.5]]
x, y = np.random.multivariate_normal(mu, cov, 200).T
plt.scatter(x, y, s=5, color='red')

plt.xlim([-2, 8])
plt.ylim([-6, 10]);

这会产生如下图所示的内容(不同的颜色,以便您可以看到图案):

enter image description here

关于python - 从混合高斯分布生成二维样本数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69065553/

相关文章:

python - 从相同的类名中提取文本(Python 网络抓取)

python - Django ManyToMany realtionship 不起作用

python - 需要一些关于为 Django 编写可重用应用程序的建议

python - 使用来自 (x,y,value) 三元组的数据创建 Numpy 二维数组

python - matplotlib - 允许条形图超出图表限制吗?

python - 在 python 中绘制流数据的最轻量级方法

Python_MySQL : AttributeError: 'function' object has no attribute 'translate'

python - 逻辑回归中的 Sigmoid(Tom Hope 的深度学习系统构建指南)

python - 在 Numpy 中将一维数组添加到三维数组

python - 楔形贴片位置未更新