새싹/TIL

[핀테커스] 231010 KNN 구현

jykim23 2023. 10. 10. 16:51

하.... 차근차근 하는데 쉬울줄 알았다.

결국 오늘 못하고 내일 완성해야지.

# KNN dataset
# 4개의 class를 가지는 dataset 만들기
n_classes = 4
n_features = 2
n_data = 100
centroid = np.random.uniform(0, 50, size=(n_features, n_classes))
K = 5

# 4개의 클래스 데이터셋
class_data = np.hstack([class_idx * np.ones(100,) for class_idx in range(n_classes)])
class_data = class_data.reshape(-1, 1)

# Target data
tmp_data_scale = 2
dataset = np.vstack([np.random.normal(centroid[:, i], tmp_data_scale, size=(n_data, n_features)) for i in range(n_classes)])
# print(dataset.shape) # (400, 2)

# target class & data
dataset = np.concatenate((dataset, class_data), axis=1)
# print(dataset.shape) # (400, 3)


# euclidean distance
sample_data = dataset[0, :n_features]

e_dists = np.linalg.norm((sample_data - dataset[:,:n_features]), axis=1).reshape(-1, 1)
# print(e_dists.shape) # (400, 1)
print(np.sort(e_dists, axis=0)[1:1+K]) # e_dists 오름정렬
dataset = np.concatenate((dataset, e_dists), axis=1)

# classify
print(np.sort(dataset[:,3], axis=0)[1:1+K].reshape(-1, 1)) # dataset의 e_dists(3) 기준 오름차순 정렬
# print(dataset[:,3==np.sort(dataset[:,3], axis=0)[1:6].reshape(-1, 1)])
# print(type(np.where(dataset[:,3]==(np.sort(e_dists, axis=0)[1:1+K]))))
tmp_tupl = np.where(dataset[:,3]==(np.sort(e_dists, axis=0)[1:1+K]))
print(np.where(dataset[:,3]==(np.sort(e_dists, axis=0)[1:1+K])))
print(dataset[tmp_tupl[1]])


# KNN visualization
# np.meshgrid
# matplot
fig, ax = plt.subplots(figsize=figsize)
for idx_class in range(n_classes):
    ax.scatter(dataset[idx_class*n_data:(idx_class+1)*n_data,0], dataset[idx_class*n_data:(idx_class+1)*n_data,1], alpha=0.5)

ax.scatter(sample_data[0], sample_data[1], marker='*', s=200) # type: ignore

'새싹 > TIL' 카테고리의 다른 글

[핀테커스] 231019 Bayes Theorem - 실습  (0) 2023.10.20
[핀테커스] 231011 KNN  (0) 2023.10.11
[핀테커스] 231006 matplot - iris  (0) 2023.10.06
[핀테커스] 231005 matplot 실습  (0) 2023.10.05
[핀테커스] 231004 matplot 실습  (0) 2023.10.04