@@ -25,8 +25,8 @@ public class NeuralNetworkRunner {
2525
2626 private static final float [] mouseYBins = {-40 , -20 , -10 , -4 , -2 , 0 , 2 , 4 , 10 , 20 , 40 };
2727
28- private static final int frame_x = 280 ;
29- private static final int frame_y = 150 ;
28+ private static final int frame_x = 80 ;
29+ private static final int frame_y = 45 ;
3030
3131 private final Minecraft mc = Minecraft .getMinecraft ();
3232 private final KeyBinding [] keybinds = {
@@ -38,7 +38,7 @@ public class NeuralNetworkRunner {
3838 private boolean running = false ;
3939 private final Gson gson = new Gson ();
4040
41- private final int TIMESTEPS = 60 ;
41+ private final int TIMESTEPS = 20 ;
4242 private final Deque <float [][][]> frameBuffer = new ArrayDeque <>();
4343
4444 private ExecutorService executor ;
@@ -51,6 +51,8 @@ public boolean running() {
5151
5252 public void startRunning () {
5353 running = true ;
54+ latestActions = null ;
55+
5456 frameBuffer .clear ();
5557 isInferenceRunning = false ;
5658 executor = Executors .newSingleThreadExecutor ();
@@ -81,15 +83,15 @@ public void onTick(TickEvent.ClientTickEvent event) {
8183 frameBuffer .removeFirst ();
8284 }
8385
84- // Convert to 96 -frame array
85- float [][][][][] frames96 = new float [ 1 ] [TIMESTEPS ][frame_x ][frame_y ][3 ];
86+ // Convert to TIMESTEPS -frame array
87+ float [][][][] frames = new float [TIMESTEPS ][frame_x ][frame_y ][3 ];
8688 int i = 0 ;
8789 for (float [][][] f : frameBuffer ) {
88- frames96 [ 0 ] [i ++] = f ;
90+ frames [i ++] = f ;
8991 }
9092
9193 if (executor != null && !isInferenceRunning && latestActions == null ) {
92- float [][][][][] framesToSend = frames96 .clone (); // clone to avoid mutation
94+ float [][][][] framesToSend = frames .clone (); // clone to avoid mutation
9395 isInferenceRunning = true ;
9496 executor .submit (() -> {
9597 List <Float > result = sendToInferenceServer (framesToSend );
@@ -98,18 +100,23 @@ public void onTick(TickEvent.ClientTickEvent event) {
98100 });
99101 }
100102
101- if (latestActions == null ) return ;
102-
103- List <Float > actions = latestActions ;
103+ if (latestActions == null ) {
104+ for (KeyBinding key : keybinds ) {
105+ KeyBinding .setKeyBindState (key .getKeyCode (), false );
106+ }
107+ KeyBinding .setKeyBindState (mc .gameSettings .keyBindAttack .getKeyCode (), false );
108+ KeyBinding .setKeyBindState (mc .gameSettings .keyBindUseItem .getKeyCode (), false );
109+ return ;
110+ }
104111
105112 for (int k = 0 ; k < 5 ; k ++) {
106- boolean pressed = actions .get (k ) > 0.5f ;
113+ boolean pressed = latestActions .get (k ) > 0.5f ;
107114 KeyBinding .setKeyBindState (keybinds [k ].getKeyCode (), pressed );
108115 }
109116
110117 // Handle left and right clicks (indexes 5 and 6)
111- boolean leftClick = actions .get (5 ) > 0.5f ;
112- boolean rightClick = actions .get (6 ) > 0.5f ;
118+ boolean leftClick = latestActions .get (5 ) > 0.5f ;
119+ boolean rightClick = latestActions .get (6 ) > 0.5f ;
113120
114121 KeyBindUtils .setKeyBindState (mc .gameSettings .keyBindAttack , leftClick );
115122 KeyBindUtils .setKeyBindState (mc .gameSettings .keyBindUseItem , rightClick );
@@ -121,7 +128,7 @@ public void onTick(TickEvent.ClientTickEvent event) {
121128 int yawIdx = 0 ;
122129 float yawMax = -1 ;
123130 for (int k = 0 ; k < mouseXBins .length ; k ++) {
124- float val = actions .get (yawStart + k );
131+ float val = latestActions .get (yawStart + k );
125132 if (val > yawMax ) {
126133 yawMax = val ;
127134 yawIdx = k ;
@@ -131,7 +138,7 @@ public void onTick(TickEvent.ClientTickEvent event) {
131138 int pitchIdx = 0 ;
132139 float pitchMax = -1 ;
133140 for (int k = 0 ; k < mouseYBins .length ; k ++) {
134- float val = actions .get (pitchStart + k );
141+ float val = latestActions .get (pitchStart + k );
135142 if (val > pitchMax ) {
136143 pitchMax = val ;
137144 pitchIdx = k ;
@@ -143,8 +150,8 @@ public void onTick(TickEvent.ClientTickEvent event) {
143150
144151 // Print descriptive action log
145152 System .out .printf ("Actions: [W: %b, A: %b, S: %b, D: %b, Space: %b, LeftClick: %b, RightClick: %b, YawDelta: %.2f, PitchDelta: %.2f]%n" ,
146- actions .get (1 ) > 0.5f , actions .get (0 ) > 0.5f , actions .get (2 ) > 0.5f , actions .get (3 ) > 0.5f , actions .get (4 ) > 0.5f ,
147- actions .get (5 ) > 0.5f , actions .get (6 ) > 0.5f , yawDelta , pitchDelta );
153+ latestActions .get (1 ) > 0.5f , latestActions .get (0 ) > 0.5f , latestActions .get (2 ) > 0.5f , latestActions .get (3 ) > 0.5f , latestActions .get (4 ) > 0.5f ,
154+ latestActions .get (5 ) > 0.5f , latestActions .get (6 ) > 0.5f , yawDelta , pitchDelta );
148155
149156 mc .thePlayer .rotationYaw += yawDelta ;
150157 mc .thePlayer .rotationPitch += pitchDelta ;
@@ -153,46 +160,62 @@ public void onTick(TickEvent.ClientTickEvent event) {
153160 }
154161
155162 private float [][][] captureAndResizeScreen () {
156- int fullW = mc .getFramebuffer ().framebufferTextureWidth ;
157- int fullH = mc .getFramebuffer ().framebufferTextureHeight ;
158- int targetW = frame_x ;
159- int targetH = frame_y ;
163+ int width = mc .getFramebuffer ().framebufferTextureWidth ;
164+ int height = mc .getFramebuffer ().framebufferTextureHeight ;
165+
166+ ByteBuffer buffer = BufferUtils . createByteBuffer ( width * height * 4 ); // 4 bytes per pixel (RGBA)
160167
161- ByteBuffer buffer = BufferUtils .createByteBuffer (fullW * fullH * 4 ); // 4 bytes per pixel (RGBA)
162168 GL11 .glReadBuffer (GL11 .GL_FRONT );
163- GL11 .glReadPixels (0 , 0 , fullW , fullH , GL12 .GL_BGRA , GL11 .GL_UNSIGNED_BYTE , buffer );
164- buffer .rewind ();
169+ GL11 .glReadPixels (0 , 0 , width , height , GL12 .GL_BGRA , GL11 .GL_UNSIGNED_BYTE , buffer );
170+
171+ // Grayscale pixel array
172+ int [] grayscaleFullRes = new int [width * height ];
173+ for (int i = 0 ; i < width * height ; i ++) {
174+ int b = buffer .get (i * 4 ) & 0xFF ;
175+ int g = buffer .get (i * 4 + 1 ) & 0xFF ;
176+ int r = buffer .get (i * 4 + 2 ) & 0xFF ;
177+ int gray = (r + g + b ) / 3 ;
178+ grayscaleFullRes [i ] = gray ;
179+ }
180+
181+ // Resize to 280x150 and convert to RGB float array normalized to [0,1]
182+ float [][][] result = new float [frame_y ][frame_x ][3 ];
165183
166- float [][][] result = new float [targetH ][targetW ][3 ];
167- int scaleX = fullW / targetW ;
168- int scaleY = fullH / targetH ;
184+ int scaleX = width / frame_x ;
185+ int scaleY = height / frame_y ;
169186
170- for (int ty = 0 ; ty < targetH ; ty ++) {
171- for (int tx = 0 ; tx < targetW ; tx ++) {
172- int r = 0 , g = 0 , b = 0 , count = 0 ;
187+ for (int ty = 0 ; ty < frame_y ; ty ++) {
188+ for (int tx = 0 ; tx < frame_x ; tx ++) {
189+ int sum = 0 ;
190+ int count = 0 ;
173191 for (int sy = 0 ; sy < scaleY ; sy ++) {
174192 for (int sx = 0 ; sx < scaleX ; sx ++) {
175- int x = tx * scaleX + sx ;
176- int y = (fullH - 1 ) - (ty * scaleY + sy ); // vertical flip
177- if (x < fullW && y < fullH ) {
178- int i = (y * fullW + x ) * 4 ;
179- b += buffer .get (i ) & 0xFF ;
180- g += buffer .get (i + 1 ) & 0xFF ;
181- r += buffer .get (i + 2 ) & 0xFF ;
193+ int sourceX = tx * scaleX + sx ;
194+ int sourceY = (height - 1 ) - (ty * scaleY + sy ); // Flip vertically
195+
196+ if (sourceX < width && sourceY < height ) {
197+ int gray = grayscaleFullRes [sourceY * width + sourceX ];
198+ sum += gray ;
182199 count ++;
183200 }
184201 }
185202 }
186- result [ty ][tx ][0 ] = r / (float ) (count * 255 );
187- result [ty ][tx ][1 ] = g / (float ) (count * 255 );
188- result [ty ][tx ][2 ] = b / (float ) (count * 255 );
203+
204+ float avgGray = sum / (float ) count ;
205+
206+ // Convert to "fake" RGB (grayscale repeated in 3 channels)
207+ result [ty ][tx ][0 ] = avgGray ;
208+ result [ty ][tx ][1 ] = avgGray ;
209+ result [ty ][tx ][2 ] = avgGray ;
189210 }
190211 }
191212
192213 return result ;
193214 }
194215
195- private List <Float > sendToInferenceServer (float [][][][][] frames96 ) {
216+
217+
218+ private List <Float > sendToInferenceServer (float [][][][] frames ) {
196219 try {
197220 URL url = new URL ("http://localhost:8000/predict" );
198221 HttpURLConnection conn = (HttpURLConnection ) url .openConnection ();
@@ -201,7 +224,7 @@ private List<Float> sendToInferenceServer(float[][][][][] frames96) {
201224 conn .setRequestProperty ("Content-Type" , "application/json" );
202225
203226 Map <String , Object > payload = new HashMap <>();
204- payload .put ("frame" , frames96 );
227+ payload .put ("frame" , frames );
205228
206229 String json = gson .toJson (payload );
207230 try (OutputStream os = conn .getOutputStream ()) {
0 commit comments