Skip to content

Commit de0f99e

Browse files
committed
80x45
1 parent 75d68d9 commit de0f99e

4 files changed

Lines changed: 74 additions & 57 deletions

File tree

0 Bytes
Binary file not shown.

.idea/compiler.xml

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/main/java/com/jelly/NeuralTrainer/data/ScreenDataCollector.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public class ScreenDataCollector {
3333
public void startRecording() {
3434
recording = true;
3535
String timestamp = new java.text.SimpleDateFormat("yyyy-MM-dd_HH-mm-ss-SSS").format(new java.util.Date());
36-
raw_dataset = new File("raw_dataset/screendata_" + mc.thePlayer.getName() + "_" + timestamp + ".txt");
36+
raw_dataset = new File("raw_dataset/screendata_80x45_" + mc.thePlayer.getName() + "_" + timestamp + ".txt");
3737
}
3838

3939
public void stopRecording() {
@@ -51,15 +51,9 @@ public void onTick(TickEvent.ClientTickEvent event) {
5151
float deltaYaw = Float.isNaN(lastYaw) ? 0 : wrapAngle(currentYaw - lastYaw);
5252
float deltaPitch = Float.isNaN(lastPitch) ? 0 : wrapAngle(currentPitch - lastPitch);
5353

54-
System.out.println("delta yaw: " + deltaYaw);
55-
System.out.println("delta print: " + deltaPitch);
56-
5754
int discretizedYaw = discretize(deltaYaw, new int[]{-90, -50, -20, -10, -4, -2, 0, 2, 4, 10, 20, 50, 90});
5855
int discretizedPitch = discretize(deltaPitch, new int[]{-40, -20, -10, -4, -2, 0, 2, 4, 10, 20, 40});
5956

60-
System.out.println("discretized yaw: " + discretizedYaw);
61-
System.out.println("discretized print: " + discretizedPitch);
62-
6357
lastYaw = currentYaw;
6458
lastPitch = currentPitch;
6559

@@ -96,9 +90,9 @@ public void onTick(TickEvent.ClientTickEvent event) {
9690
grayscaleFullRes[i] = gray;
9791
}
9892

99-
// Downscale to 280 x 150
100-
int targetWidth = 280;
101-
int targetHeight = 150;
93+
// Downscale to 80 x 45
94+
int targetWidth = 80;
95+
int targetHeight = 45;
10296
int scaleX = width / targetWidth;
10397
int scaleY = height / targetHeight;
10498

src/main/java/com/jelly/NeuralTrainer/run/NeuralNetworkRunner.java

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ public class NeuralNetworkRunner {
2525

2626
private static final float[] mouseYBins = {-40, -20, -10, -4, -2, 0, 2, 4, 10, 20, 40};
2727

28-
private static final int frame_x = 280;
29-
private static final int frame_y = 150;
28+
private static final int frame_x = 80;
29+
private static final int frame_y = 45;
3030

3131
private final Minecraft mc = Minecraft.getMinecraft();
3232
private final KeyBinding[] keybinds = {
@@ -38,7 +38,7 @@ public class NeuralNetworkRunner {
3838
private boolean running = false;
3939
private final Gson gson = new Gson();
4040

41-
private final int TIMESTEPS = 60;
41+
private final int TIMESTEPS = 20;
4242
private final Deque<float[][][]> frameBuffer = new ArrayDeque<>();
4343

4444
private ExecutorService executor;
@@ -51,6 +51,8 @@ public boolean running() {
5151

5252
public void startRunning() {
5353
running = true;
54+
latestActions = null;
55+
5456
frameBuffer.clear();
5557
isInferenceRunning = false;
5658
executor = Executors.newSingleThreadExecutor();
@@ -81,15 +83,15 @@ public void onTick(TickEvent.ClientTickEvent event) {
8183
frameBuffer.removeFirst();
8284
}
8385

84-
// Convert to 96-frame array
85-
float[][][][][] frames96 = new float[1][TIMESTEPS][frame_x][frame_y][3];
86+
// Convert to TIMESTEPS-frame array
87+
float[][][][] frames = new float[TIMESTEPS][frame_x][frame_y][3];
8688
int i = 0;
8789
for (float[][][] f : frameBuffer) {
88-
frames96[0][i++] = f;
90+
frames[i++] = f;
8991
}
9092

9193
if (executor != null && !isInferenceRunning && latestActions == null) {
92-
float[][][][][] framesToSend = frames96.clone(); // clone to avoid mutation
94+
float[][][][] framesToSend = frames.clone(); // clone to avoid mutation
9395
isInferenceRunning = true;
9496
executor.submit(() -> {
9597
List<Float> result = sendToInferenceServer(framesToSend);
@@ -98,18 +100,23 @@ public void onTick(TickEvent.ClientTickEvent event) {
98100
});
99101
}
100102

101-
if (latestActions == null) return;
102-
103-
List<Float> actions = latestActions;
103+
if (latestActions == null) {
104+
for (KeyBinding key : keybinds) {
105+
KeyBinding.setKeyBindState(key.getKeyCode(), false);
106+
}
107+
KeyBinding.setKeyBindState(mc.gameSettings.keyBindAttack.getKeyCode(), false);
108+
KeyBinding.setKeyBindState(mc.gameSettings.keyBindUseItem.getKeyCode(), false);
109+
return;
110+
}
104111

105112
for (int k = 0; k < 5; k++) {
106-
boolean pressed = actions.get(k) > 0.5f;
113+
boolean pressed = latestActions.get(k) > 0.5f;
107114
KeyBinding.setKeyBindState(keybinds[k].getKeyCode(), pressed);
108115
}
109116

110117
// Handle left and right clicks (indexes 5 and 6)
111-
boolean leftClick = actions.get(5) > 0.5f;
112-
boolean rightClick = actions.get(6) > 0.5f;
118+
boolean leftClick = latestActions.get(5) > 0.5f;
119+
boolean rightClick = latestActions.get(6) > 0.5f;
113120

114121
KeyBindUtils.setKeyBindState(mc.gameSettings.keyBindAttack, leftClick);
115122
KeyBindUtils.setKeyBindState(mc.gameSettings.keyBindUseItem, rightClick);
@@ -121,7 +128,7 @@ public void onTick(TickEvent.ClientTickEvent event) {
121128
int yawIdx = 0;
122129
float yawMax = -1;
123130
for (int k = 0; k < mouseXBins.length; k++) {
124-
float val = actions.get(yawStart + k);
131+
float val = latestActions.get(yawStart + k);
125132
if (val > yawMax) {
126133
yawMax = val;
127134
yawIdx = k;
@@ -131,7 +138,7 @@ public void onTick(TickEvent.ClientTickEvent event) {
131138
int pitchIdx = 0;
132139
float pitchMax = -1;
133140
for (int k = 0; k < mouseYBins.length; k++) {
134-
float val = actions.get(pitchStart + k);
141+
float val = latestActions.get(pitchStart + k);
135142
if (val > pitchMax) {
136143
pitchMax = val;
137144
pitchIdx = k;
@@ -143,8 +150,8 @@ public void onTick(TickEvent.ClientTickEvent event) {
143150

144151
// Print descriptive action log
145152
System.out.printf("Actions: [W: %b, A: %b, S: %b, D: %b, Space: %b, LeftClick: %b, RightClick: %b, YawDelta: %.2f, PitchDelta: %.2f]%n",
146-
actions.get(1) > 0.5f, actions.get(0) > 0.5f, actions.get(2) > 0.5f, actions.get(3) > 0.5f, actions.get(4) > 0.5f,
147-
actions.get(5) > 0.5f, actions.get(6) > 0.5f, yawDelta, pitchDelta);
153+
latestActions.get(1) > 0.5f, latestActions.get(0) > 0.5f, latestActions.get(2) > 0.5f, latestActions.get(3) > 0.5f, latestActions.get(4) > 0.5f,
154+
latestActions.get(5) > 0.5f, latestActions.get(6) > 0.5f, yawDelta, pitchDelta);
148155

149156
mc.thePlayer.rotationYaw += yawDelta;
150157
mc.thePlayer.rotationPitch += pitchDelta;
@@ -153,46 +160,62 @@ public void onTick(TickEvent.ClientTickEvent event) {
153160
}
154161

155162
private float[][][] captureAndResizeScreen() {
156-
int fullW = mc.getFramebuffer().framebufferTextureWidth;
157-
int fullH = mc.getFramebuffer().framebufferTextureHeight;
158-
int targetW = frame_x;
159-
int targetH = frame_y;
163+
int width = mc.getFramebuffer().framebufferTextureWidth;
164+
int height = mc.getFramebuffer().framebufferTextureHeight;
165+
166+
ByteBuffer buffer = BufferUtils.createByteBuffer(width * height * 4); // 4 bytes per pixel (RGBA)
160167

161-
ByteBuffer buffer = BufferUtils.createByteBuffer(fullW * fullH * 4); // 4 bytes per pixel (RGBA)
162168
GL11.glReadBuffer(GL11.GL_FRONT);
163-
GL11.glReadPixels(0, 0, fullW, fullH, GL12.GL_BGRA, GL11.GL_UNSIGNED_BYTE, buffer);
164-
buffer.rewind();
169+
GL11.glReadPixels(0, 0, width, height, GL12.GL_BGRA, GL11.GL_UNSIGNED_BYTE, buffer);
170+
171+
// Grayscale pixel array
172+
int[] grayscaleFullRes = new int[width * height];
173+
for (int i = 0; i < width * height; i++) {
174+
int b = buffer.get(i * 4) & 0xFF;
175+
int g = buffer.get(i * 4 + 1) & 0xFF;
176+
int r = buffer.get(i * 4 + 2) & 0xFF;
177+
int gray = (r + g + b) / 3;
178+
grayscaleFullRes[i] = gray;
179+
}
180+
181+
// Resize to 280x150 and convert to RGB float array normalized to [0,1]
182+
float[][][] result = new float[frame_y][frame_x][3];
165183

166-
float[][][] result = new float[targetH][targetW][3];
167-
int scaleX = fullW / targetW;
168-
int scaleY = fullH / targetH;
184+
int scaleX = width / frame_x;
185+
int scaleY = height / frame_y;
169186

170-
for (int ty = 0; ty < targetH; ty++) {
171-
for (int tx = 0; tx < targetW; tx++) {
172-
int r = 0, g = 0, b = 0, count = 0;
187+
for (int ty = 0; ty < frame_y; ty++) {
188+
for (int tx = 0; tx < frame_x; tx++) {
189+
int sum = 0;
190+
int count = 0;
173191
for (int sy = 0; sy < scaleY; sy++) {
174192
for (int sx = 0; sx < scaleX; sx++) {
175-
int x = tx * scaleX + sx;
176-
int y = (fullH - 1) - (ty * scaleY + sy); // vertical flip
177-
if (x < fullW && y < fullH) {
178-
int i = (y * fullW + x) * 4;
179-
b += buffer.get(i) & 0xFF;
180-
g += buffer.get(i + 1) & 0xFF;
181-
r += buffer.get(i + 2) & 0xFF;
193+
int sourceX = tx * scaleX + sx;
194+
int sourceY = (height - 1) - (ty * scaleY + sy); // Flip vertically
195+
196+
if (sourceX < width && sourceY < height) {
197+
int gray = grayscaleFullRes[sourceY * width + sourceX];
198+
sum += gray;
182199
count++;
183200
}
184201
}
185202
}
186-
result[ty][tx][0] = r / (float) (count * 255);
187-
result[ty][tx][1] = g / (float) (count * 255);
188-
result[ty][tx][2] = b / (float) (count * 255);
203+
204+
float avgGray = sum / (float) count;
205+
206+
// Convert to "fake" RGB (grayscale repeated in 3 channels)
207+
result[ty][tx][0] = avgGray;
208+
result[ty][tx][1] = avgGray;
209+
result[ty][tx][2] = avgGray;
189210
}
190211
}
191212

192213
return result;
193214
}
194215

195-
private List<Float> sendToInferenceServer(float[][][][][] frames96) {
216+
217+
218+
private List<Float> sendToInferenceServer(float[][][][] frames) {
196219
try {
197220
URL url = new URL("http://localhost:8000/predict");
198221
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
@@ -201,7 +224,7 @@ private List<Float> sendToInferenceServer(float[][][][][] frames96) {
201224
conn.setRequestProperty("Content-Type", "application/json");
202225

203226
Map<String, Object> payload = new HashMap<>();
204-
payload.put("frame", frames96);
227+
payload.put("frame", frames);
205228

206229
String json = gson.toJson(payload);
207230
try (OutputStream os = conn.getOutputStream()) {

0 commit comments

Comments
 (0)