Ich muss die Dateien precess.java:
/*
* 4/28/24
* Final
*/
package Final;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
/*
* To DO
* Add labels to program
* One hot encode labels
* xavier insantztion
*/
public class PreProcess {
public static double[][] processImages(int numImagesToRead) throws FileNotFoundException, IOException {
String filePath = "C:\\Users\\Mark\\APCSA\\Final\\samples\\train-images.idx3-ubyte";
try (DataInputStream inputStream = new DataInputStream(new FileInputStream(filePath))) {
int magicNumber = inputStream.readInt();
if (magicNumber != 0x00000803) {
System.err.println("Invalid magic number. This may not be a valid image file.");
return null;
}
int numImages = inputStream.readInt();
int numRows = inputStream.readInt();
int numColumns = inputStream.readInt();
System.out.println("Processing " + numImagesToRead + " " + numRows + "x" + numColumns + " images");
byte[][][] images = new byte[numImages][numRows][numColumns];
for (int i = 0; i < numImagesToRead; i++) { // Changed loop condition
for (int row = 0; row < numRows; row++) {
for (int col = 0; col < numColumns; col++) {
images[row][col] = inputStream.readByte();
}
}
if (i % 10 == 0) {
updateProgress(i, numImagesToRead);
}
}
double[][] orderedImages = new double[numImagesToRead][numRows * numColumns]; // Changed to numImagesToRead
for (int i = 0; i < numImagesToRead; i++) {
orderedImages = minMaxNormalization(flaten(images, numRows, numColumns));
}
System.out.println("");
System.out.println("Finished processing images!");
return transpose(orderedImages);
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public static double[][] processLabels(int numLabelsToRead) {
String labelFilePath = "C:\\Users\\Mark\\APCSA\\Final\\samples\\train-labels.idx1-ubyte";
try (DataInputStream inputStream = new DataInputStream(new FileInputStream(labelFilePath))) {
int magicNumber = inputStream.readInt();
if (magicNumber != 0x00000801) {
System.err.println("Invalid magic number. This may not be a valid labels file.");
return null;
}
int numLabels = numLabelsToRead;
byte[] labels = new byte[numLabels];
inputStream.read(labels);
System.out.println("Processing " + numLabels + " labels");
double[] orderedLabels = new double[numLabels];
for (int i = 0; i < numLabels; i++) {
orderedLabels = labels;
}
double[] categories = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};
System.out.println("Finiished processing labels!");
return oneHotEncode(orderedLabels, categories);
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
public static byte[] flaten(byte[][] image, int numRows, int numColumns) {
byte[] temp = new byte[numColumns * numRows];
int i = 0;
for (int row = 0; row < numRows; row++) {
for (int col = 0; col < numColumns; col++) {
temp = image[row][col];
i++;
}
}
return temp;
}
public static double[] minMaxNormalization(byte[] image) {
byte minPixelValue = Byte.MAX_VALUE;
byte maxPixelValue = Byte.MIN_VALUE;
for (byte pixelValue : image) {
if (pixelValue < minPixelValue) {
minPixelValue = pixelValue;
}
if (pixelValue > maxPixelValue) {
maxPixelValue = pixelValue;
}
}
double[] newImages = new double[image.length];
for (int i = 0; i < image.length; i++) {
newImages = ((image- minPixelValue)/(maxPixelValue - minPixelValue));
}
return newImages;
}
public static double[][] transpose(double[][] array) {
int length = array.length;
int imageSize = array[0].length;
double[][] transposedMatrix = new double[imageSize][length];
for (int i = 0; i < imageSize; i++) {
for (int j = 0; j < length; j++) {
transposedMatrix[j] = array[j];
}
}
return transposedMatrix;
}
public static double[][] oneHotEncode(double[] labels, double[] categories) {
double[][] result = new double[categories.length][labels.length];
for (int i = 0; i < categories.length; i++) {
for (int j = 0; j < labels.length; j++) {
if (labels[j] == categories[i]) {
result[i][j] = 1.0;
}
}
}
return result;
}
public static void updateProgress(int currentStep, int totalSteps) {
double progress = (double) currentStep / totalSteps;
int barLength = 100;
System.out.print("\r[");
int progressChars = (int) (progress * barLength);
for (int i = 0; i < barLength; i++) {
if (i < progressChars) {
System.out.print("=");
} else {
System.out.print(" ");
}
}
System.out.printf("] %.2f%%", progress * 100);
}
// public static void accuracy() {
// }
public static double[][] dot(double[][] matrixA, double[][] matrixB) {
double[][] newMatrix = new double[matrixA.length][matrixB[0].length];
for (int i = 0; i < newMatrix.length; i++) {
for (int j = 0; j < newMatrix[i].length; j++) {
for (int k = 0; k < matrixB.length; k++) {
newMatrix[i][j] += matrixA[i][k] * matrixB[k][j];
}
}
}
return newMatrix;
}
public static double[][] matrixCoefficientMultiplication(double[][] matrix, double coefficent) {
double[][] newMatrix = new double[matrix.length][matrix[0].length];
for (int i = 0; i < newMatrix.length; i++) {
for (int j = 0; j < newMatrix[i].length; j++) {
newMatrix[i][j] = matrix[i][j] * coefficent;
}
}
return newMatrix;
}
public static double[][] matrixOperations(double[][] matrixA, double[][] matrixB, boolean subtraction) {
double[][] newMatrix = new double[matrixA.length][matrixA[0].length];
double[][] newMatrixB;
if (subtraction) {
newMatrixB = matrixCoefficientMultiplication(matrixB, -1.0);
}
else {
newMatrixB = matrixB;
}
for (int i = 0; i < newMatrix.length; i++) {
for (int j = 0; j < newMatrix[i].length; j++) {
newMatrix[i][j] = matrixA[i][j] + newMatrixB[i][j];
}
}
return newMatrix;
}
public static double[][] matrixExp(double[][] matrix) {
double[][] newMatrix = new double[matrix.length][matrix[0].length];
for (int i = 0; i < newMatrix.length; i++) {
for (int j = 0; j < newMatrix[i].length; j++) {
newMatrix[i][j] = Math.exp(matrix[i][j]);
}
}
return newMatrix;
}
public static double sum(double[][] matrix) {
double sum = 0;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
sum += matrix[i][j];
}
}
return sum;
}
public static double[][] sumSecondAxis(double[][] array) {
double[][] sums = new double[array.length][1];
for (int i = 0; i < array.length; i++) {
double sum = 0.0;
for (int j = 0; j < array[i].length; j++) {
sum += array[i][j];
}
sums[i][0] = sum;
}
return sums;
}
public static double[][] reshape(double[][] matrix) {
double[][] newMatrix = new double[matrix.length * matrix[0].length][1];
int k = 0;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
newMatrix[k][0] = matrix[i][j];
k++;
}
}
return newMatrix;
}
public static double[] reshape(double[][] matrix, int numColumns) {
double[] newMatrix = new double[matrix.length * matrix[0].length];
int k = 0;
for (int i = 0; i < matrix.length; i++) {
for (int j = 0; j < matrix[i].length; j++) {
newMatrix[k] = matrix[i][j];
k++;
}
}
return newMatrix;
}
public static double[][] copyAcross(double[][] matrix, int numColumns) {
double[][] result = new double[matrix.length][numColumns];
for (int i = 0; i < matrix.length; i++) {
double value = matrix[i][0];
for (int j = 0; j < numColumns; j++) {
result[i][j] = value;
}
}
return result;
}
}
< /code>
und main.java in einem Ordner namens Final: < /p>
package Final;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Random;
import java.util.Scanner;
class ActivationFunction {
// possibly swith to rmsprop for better op
public static double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
}
public static double sigmoidDerivative(double x) {
return sigmoid(x) * (1 - sigmoid(x));
}
public static double[][] rectifiedLinearUnit(double[][] matrix) {
double[][] newMatrix = new double[matrix.length][matrix[0].length];
for (int i = 0; i < newMatrix.length; i++) {
for (int j = 0; j < newMatrix[i].length; j++) {
if (matrix[i][j]
Ich folgte einem Python -Tutorial und diesem Kommentar, konnte es aber nicht verwenden.Epoch: 0 Accuarcy: 0.1035
Epoch: 50 Accuarcy: 0.1015
Epoch: 100 Accuarcy: 0.1025
Epoch: 150 Accuarcy: 0.1035
Epoch: 200 Accuarcy: 0.105
Epoch: 250 Accuarcy: 0.104
Epoch: 300 Accuarcy: 0.103
Epoch: 350 Accuarcy: 0.1035
< /code>
Ich habe versucht, überall nachzuschauen, und es gibt keine Java -Implementierungen eines neuronalen Netzwerks. Können Sie bitte Änderungen vorschlagen oder warum dies möglicherweise geschieht? Jede Hilfe wäre sehr geschätzt, da dies für ein endgültiges Projekt ist und bevor Sie mich anschreien, entscheide ich mich nicht bereitwillig, dies in Java umzusetzen.
MNIST Image Classification Gradient Descent Neural Network funktioniert nicht ⇐ Java
-
- Similar Topics
- Replies
- Views
- Last post