Skip to content
Snippets Groups Projects
Commit 867372ed authored by Jakub Nenczak's avatar Jakub Nenczak
Browse files

Upload New File

parent 6a7fe70c
Branches
No related merge requests found
# test_one.py
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from collections import defaultdict
def get_working_directory(): #funkcja
if getattr(sys, 'frozen', False):
return os.path.dirname(sys.executable)
else:
return os.path.dirname(os.path.abspath(__file__))
os.chdir(get_working_directory())
#import funkcji z pliku model.py
from model import (
conv_forward,pool_forward,relu,softmax
)
#definicja nazw klas dla lepszej czytelności
CLASS_NAMES = [
"Airplane",
"Bird",
"Car",
"Cat",
"Deer",
"Dog",
"Horse",
"Monkey",
"Ship",
"Truck"
]
def read_stl10_images(path_to_bin, num_images):
#funkcja do wczytywania obrazów z pliku binarnego STL-10
with open(path_to_bin, 'rb') as f:
everything = np.fromfile(f, dtype=np.uint8)
images = everything.reshape(num_images, 3, 96, 96)
images = np.transpose(images, (0, 2, 3, 1))
return images
def read_stl10_labels(path_to_bin, num_labels):
#funkcja do wczytywania etykiet z pliku binarnego STL-10
with open(path_to_bin, 'rb') as f:
labels = np.fromfile(f, dtype=np.uint8)
labels = labels - 1
return labels
def load_stl10(root_path="./stl10_binary"):
#funkcja do ładowania zestawu danych STL-10
X_train = read_stl10_images(os.path.join(root_path, 'train_X.bin'), 5000)
y_train = read_stl10_labels(os.path.join(root_path, 'train_y.bin'), 5000)
X_test = read_stl10_images(os.path.join(root_path, 'test_X.bin'), 8000)
y_test = read_stl10_labels(os.path.join(root_path, 'test_y.bin'), 8000)
return X_train, y_train, X_test, y_test
#wczytywanie danych treningowych i testowych
X_train, y_train, X_test, y_test = load_stl10()
#normalizacja danych testowych
X_test = X_test / 255.0
checkpoint_path = "model_params.npz"
#wczytanie parametrów modelu z pliku .npz
data = np.load(checkpoint_path)
W1 = data["W1"]
b1 = data["b1"]
W2_conv = data["W2_conv"]
b2_conv = data["b2_conv"]
W_fc1 = data["W_fc1"]
b_fc1 = data["b_fc1"]
W_fc2 = data["W_fc2"]
b_fc2 = data["b_fc2"]
conv_hparams1 = {"stride":1, "pad":1}
pool_hparams1 = {"f":2, "stride":2}
conv_hparams2 = {"stride":1, "pad":1}
pool_hparams2 = {"f":2, "stride":2}
def forward_model(X, W1, b1, W2_conv, b2_conv, W_fc1, b_fc1, W_fc2, b_fc2):
#foward dla sieci
A1, _ = conv_forward(X, W1, b1, conv_hparams1, activation=relu)
A2, _ = pool_forward(A1, pool_hparams1, mode="max")
A3, _ = conv_forward(A2, W2_conv, b2_conv, conv_hparams2, activation=relu)
A4, _ = pool_forward(A3, pool_hparams2, mode="max")
A4_flat = A4.reshape(A4.shape[0], -1)
Z_fc1 = A4_flat.dot(W_fc1) + b_fc1
A_fc1 = relu(Z_fc1)
Z_fc2 = A_fc1.dot(W_fc2) + b_fc2
A_fc2 = softmax(Z_fc2)
return A_fc2
print("Wczytano parametry z pliku:", checkpoint_path)
while True:
index_str = input(f"Podaj indeks obrazu testowego (0 - {X_test.shape[0]-1}) lub 'q' aby wyjść: ")
if index_str.lower() == 'q':
break
try:
index = int(index_str)
if index < 0 or index >= X_test.shape[0]:
print(f"Nieprawidłowy indeks. Proszę podać liczbę z zakresu 0 - {X_test.shape[0]-1}.")
continue
except ValueError:
print("Proszę podać liczbę lub 'q'.")
continue
img = X_test[index]
true_label = y_test[index]
A_fc2_test = forward_model(img.reshape(1, 96, 96, 3),
W1, b1,
W2_conv, b2_conv,
W_fc1, b_fc1,
W_fc2, b_fc2)
pred_label = np.argmax(A_fc2_test, axis=1)[0]
correctness = "Poprawnie" if pred_label == true_label else "Niepoprawnie"
#wyświetlenie obrazu (obracanie o 90 stopni w lewo dla lepszej wizualizacji)
rotated_img = np.rot90(img, k=3)
plt.imshow(rotated_img)
plt.axis('off') #wyłączenie osi dla lepszej estetyki
plt.title(f"Prawdziwa etykieta: {CLASS_NAMES[true_label]}, Predykcja: {CLASS_NAMES[pred_label]} ({correctness})")
plt.show()
#koniec programu
print("Program zakończył działanie.")
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment