361 lines
10 KiB
C
361 lines
10 KiB
C
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <math.h>
|
|
#include "raylib.h"
|
|
#include <fann.h>
|
|
|
|
/* Capacity of the training buffers, counted in (input, output) sample pairs. */
enum { MAX_SAMPLES = 20000 };

/* Window dimensions in pixels. */
enum
{
    SCREEN_WIDTH = 800,
    SCREEN_HEIGHT = 450
};

/* Network file written by save_weights() and read by load_weights(). */
#define WEIGHTS_FILE "creature_brain.net"
|
|
|
|
struct game_state
|
|
{
|
|
Vector2 square_position;
|
|
Vector2 square_size;
|
|
float move_speed;
|
|
Vector2 food_position;
|
|
float food_radius;
|
|
struct fann *ann;
|
|
fann_type *recorded_inputs;
|
|
fann_type *recorded_outputs;
|
|
int sample_count;
|
|
bool is_ai_mode;
|
|
float dx;
|
|
float dy;
|
|
};
|
|
|
|
/* Frame-loop entry points; each receives the single game_state owned by main(). */
void init_game(struct game_state *state);
void game_logic(struct game_state *state);
void game_renderer(struct game_state *state);

/* Per-mode input handling, dispatched from game_logic(). */
void manual_control_mode(struct game_state *state);
void ai_control_mode(struct game_state *state);

/* Network persistence, bound to the 'N' (save) and 'L' (load) keys. */
void save_weights(struct game_state *state);
void load_weights(struct game_state *state);

/* Releases everything init_game() acquired and closes the window. */
void game_cleanup(struct game_state *state);
|
|
|
|
int
|
|
main(void)
|
|
{
|
|
struct game_state state = {0};
|
|
init_game(&state);
|
|
|
|
while (!WindowShouldClose())
|
|
{
|
|
game_logic(&state);
|
|
game_renderer(&state);
|
|
}
|
|
|
|
game_cleanup(&state);
|
|
return 0;
|
|
}
|
|
|
|
void
|
|
init_game(struct game_state *state)
|
|
{
|
|
InitWindow(SCREEN_WIDTH,
|
|
SCREEN_HEIGHT,
|
|
"Creature Sim - FANN Training");
|
|
|
|
state->square_position = (Vector2)
|
|
{
|
|
(float)SCREEN_WIDTH / 2,
|
|
(float)SCREEN_HEIGHT / 2
|
|
};
|
|
|
|
state->square_size = (Vector2)
|
|
{
|
|
40.0f,
|
|
40.0f
|
|
};
|
|
|
|
state->move_speed = 5.0f;
|
|
|
|
state->food_position = (Vector2)
|
|
{
|
|
(float)GetRandomValue(50, SCREEN_WIDTH - 50),
|
|
(float)GetRandomValue(50, SCREEN_HEIGHT - 50)
|
|
};
|
|
|
|
state->food_radius = 10.0f;
|
|
|
|
// FANN Setup: 3 layers (2 Inputs, 8 Hidden Neurons, 2 Outputs)
|
|
state->ann = fann_create_standard(3,
|
|
2,
|
|
8,
|
|
2);
|
|
|
|
state->recorded_inputs = malloc(MAX_SAMPLES * 2 * sizeof(fann_type));
|
|
state->recorded_outputs = malloc(MAX_SAMPLES * 2 * sizeof(fann_type));
|
|
state->sample_count = 0;
|
|
state->is_ai_mode = false;
|
|
|
|
fann_set_activation_function_hidden(state->ann,
|
|
FANN_SIGMOID_SYMMETRIC);
|
|
fann_set_activation_function_output(state->ann,
|
|
FANN_SIGMOID_SYMMETRIC);
|
|
|
|
SetTargetFPS(60);
|
|
}
|
|
|
|
void
|
|
game_logic(struct game_state *state)
|
|
{
|
|
state->dx = (state->food_position.x - state->square_position.x) / SCREEN_WIDTH;
|
|
state->dy = (state->food_position.y - state->square_position.y) / SCREEN_HEIGHT;
|
|
save_weights(state);
|
|
|
|
if (!state->is_ai_mode)
|
|
{
|
|
manual_control_mode(state);
|
|
}
|
|
else
|
|
{
|
|
ai_control_mode(state);
|
|
}
|
|
|
|
if (state->square_position.x < 0)
|
|
{
|
|
state->square_position.x = 0;
|
|
}
|
|
if (state->square_position.x > SCREEN_WIDTH - state->square_size.x)
|
|
{
|
|
state->square_position.x = SCREEN_WIDTH - state->square_size.x;
|
|
}
|
|
if (state->square_position.y < 0)
|
|
{
|
|
state->square_position.y = 0;
|
|
}
|
|
if (state->square_position.y > SCREEN_HEIGHT - state->square_size.y)
|
|
{
|
|
state->square_position.y = SCREEN_HEIGHT - state->square_size.y;
|
|
}
|
|
|
|
float sq_center_x = state->square_position.x + state->square_size.x / 2.0f;
|
|
float sq_center_y = state->square_position.y + state->square_size.y / 2.0f;
|
|
|
|
float dist = sqrt(pow(sq_center_x - state->food_position.x, 2) +
|
|
pow(sq_center_y - state->food_position.y, 2));
|
|
|
|
if (dist < (state->square_size.x / 2.0f + state->food_radius))
|
|
{
|
|
state->food_position.x = (float)GetRandomValue(50, SCREEN_WIDTH - 50);
|
|
state->food_position.y = (float)GetRandomValue(50, SCREEN_HEIGHT - 50);
|
|
}
|
|
}
|
|
|
|
void
|
|
game_renderer(struct game_state *state)
|
|
{
|
|
BeginDrawing();
|
|
ClearBackground(RAYWHITE);
|
|
|
|
if (!state->is_ai_mode)
|
|
{
|
|
DrawText("MANUAL MODE: Use arrow keys to chase the food.",
|
|
10,
|
|
10,
|
|
20,
|
|
DARKGRAY);
|
|
|
|
DrawText(TextFormat("Samples collected: %d / %d",
|
|
state->sample_count,
|
|
MAX_SAMPLES),
|
|
10,
|
|
40,
|
|
20,
|
|
MAROON);
|
|
|
|
DrawText("Press 'T' to Train | Press 'L' to Load weights",
|
|
10,
|
|
70,
|
|
20,
|
|
DARKGREEN);
|
|
}
|
|
else
|
|
{
|
|
DrawText("AI MODE: Neural Network is driving!",
|
|
10,
|
|
10,
|
|
20,
|
|
DARKBLUE);
|
|
|
|
DrawText("Press 'M' to return to Manual Mode.",
|
|
10,
|
|
40,
|
|
20,
|
|
DARKGRAY);
|
|
}
|
|
|
|
DrawText("Press 'N' anywhere to Save weights to file",
|
|
10,
|
|
SCREEN_HEIGHT - 30,
|
|
20,
|
|
GRAY);
|
|
|
|
DrawCircleV(state->food_position,
|
|
state->food_radius,
|
|
LIME);
|
|
|
|
DrawRectangleV(state->square_position,
|
|
state->square_size,
|
|
state->is_ai_mode ? DARKBLUE : MAROON);
|
|
|
|
EndDrawing();
|
|
}
|
|
|
|
void
|
|
manual_control_mode(struct game_state *state)
|
|
{
|
|
float out_x = 0.0f;
|
|
float out_y = 0.0f;
|
|
bool user_moved = false;
|
|
|
|
// Allow loading from file while in manual mode
|
|
load_weights(state);
|
|
|
|
if (IsKeyDown(KEY_RIGHT))
|
|
{
|
|
state->square_position.x += state->move_speed;
|
|
out_x = 1.0f;
|
|
user_moved = true;
|
|
}
|
|
if (IsKeyDown(KEY_LEFT))
|
|
{
|
|
state->square_position.x -= state->move_speed;
|
|
out_x = -1.0f;
|
|
user_moved = true;
|
|
}
|
|
if (IsKeyDown(KEY_UP))
|
|
{
|
|
state->square_position.y -= state->move_speed;
|
|
out_y = -1.0f;
|
|
user_moved = true;
|
|
}
|
|
if (IsKeyDown(KEY_DOWN))
|
|
{
|
|
state->square_position.y += state->move_speed;
|
|
out_y = 1.0f;
|
|
user_moved = true;
|
|
}
|
|
|
|
if (user_moved
|
|
&& state->sample_count < MAX_SAMPLES)
|
|
{
|
|
state->recorded_inputs[state->sample_count * 2 + 0] = state->dx;
|
|
state->recorded_inputs[state->sample_count * 2 + 1] = state->dy;
|
|
|
|
state->recorded_outputs[state->sample_count * 2 + 0] = out_x;
|
|
state->recorded_outputs[state->sample_count * 2 + 1] = out_y;
|
|
|
|
state->sample_count++;
|
|
}
|
|
|
|
if (IsKeyPressed(KEY_T)
|
|
&& state->sample_count > 10)
|
|
{
|
|
printf("Training FANN on %d samples...\n", state->sample_count);
|
|
|
|
struct fann_train_data *train_data =
|
|
fann_create_train(state->sample_count,
|
|
2,
|
|
2);
|
|
|
|
for(int i = 0;
|
|
i < state->sample_count;
|
|
i++)
|
|
{
|
|
train_data->input[i][0] = state->recorded_inputs[i * 2 + 0];
|
|
train_data->input[i][1] = state->recorded_inputs[i * 2 + 1];
|
|
train_data->output[i][0] = state->recorded_outputs[i * 2 + 0];
|
|
train_data->output[i][1] = state->recorded_outputs[i * 2 + 1];
|
|
}
|
|
|
|
fann_train_on_data(state->ann, train_data, 10000, 100, 0.01);
|
|
fann_destroy_train(train_data);
|
|
|
|
state->is_ai_mode = true;
|
|
}
|
|
}
|
|
|
|
void
|
|
ai_control_mode(struct game_state *state)
|
|
{
|
|
fann_type current_inputs[2] = { state->dx, state->dy };
|
|
fann_type *calc_out = fann_run(state->ann, current_inputs);
|
|
|
|
state->square_position.x += calc_out[0] * state->move_speed;
|
|
state->square_position.y += calc_out[1] * state->move_speed;
|
|
|
|
if (IsKeyPressed(KEY_M))
|
|
{
|
|
state->is_ai_mode = false;
|
|
}
|
|
}
|
|
|
|
void
|
|
save_weights(struct game_state *state)
|
|
{
|
|
if (IsKeyPressed(KEY_N))
|
|
{
|
|
if (fann_save(state->ann, WEIGHTS_FILE) == 0)
|
|
{
|
|
printf("SUCCESS: Neural weights saved to %s\n", WEIGHTS_FILE);
|
|
}
|
|
else
|
|
{
|
|
printf("ERROR: Failed to save weights to %s\n", WEIGHTS_FILE);
|
|
}
|
|
}
|
|
}
|
|
|
|
void
|
|
load_weights(struct game_state *state)
|
|
{
|
|
if (IsKeyPressed(KEY_L))
|
|
{
|
|
struct fann *loaded_ann = fann_create_from_file(WEIGHTS_FILE);
|
|
|
|
if (loaded_ann != NULL)
|
|
{
|
|
// Clean up the old blank network before replacing it
|
|
fann_destroy(state->ann);
|
|
state->ann = loaded_ann;
|
|
|
|
state->is_ai_mode = true;
|
|
|
|
printf("SUCCESS: Loaded weights from %s. Entering AI mode.\n",
|
|
WEIGHTS_FILE);
|
|
}
|
|
else
|
|
{
|
|
printf("ERROR: Could not load %s. Has it been saved yet?\n",
|
|
WEIGHTS_FILE);
|
|
}
|
|
}
|
|
}
|
|
|
|
void
|
|
game_cleanup(struct game_state *state)
|
|
{
|
|
free(state->recorded_inputs);
|
|
free(state->recorded_outputs);
|
|
fann_destroy(state->ann);
|
|
CloseWindow();
|
|
}
|