/*

Macroglossum.cpp

This is the C++ code for the simulations reported in:

Balkenius, A., Kelber, A., and Balkenius, C. (2006). Modelling Multi-Modal Learning in a Hawkmoth.
In Nolfi, S. et al., From Animals to Animats 9: Proceedings of the Ninth International Conference on
the Simulation of Adaptive Behavior (SAB'06). (pp. 422-433). Berlin: Springer-Verlag.

Copyright (c) 2006 Anna Balkenius & Christian Balkenius

*/





#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>

// Model to test:
//   0 = Categorization model; 1 = Rescorla-Wagner; 2 = Independence.
// Experiment_Run also has a path for 3: choices are made at random (Moth_Select
// is skipped) and learning falls through to the Independence rule.
const int model = 0;

// Learning rates. LR2_Categorization is the fixed amount subtracted from the
// weights of features that are absent on a rewarded visit (categorization model).
const float LR_Categorization = 0.05;
const float LR2_Categorization = 0.1;
const float LR_RescorlaWagner = 0.05;
const float LR_Independence = 0.04;

const int number_of_animals = 1000;

// Response exponent used by Moth_Select; Experiment_Run sets it to 2 for model 0.
float exponent = 1;

// Initial associative weights, one per stimulus feature in the order used by
// Phase_SetUp: Y, B, G, H, L. They are only read (copied into w at every reset),
// so they are const to prevent accidental modification between animals.
const float w_init_Categorization[5] = { 0.25, 1.0, 0.0, 0.1, 0.0 };
const float w_init_RescorlaWagner[5] = { 0.1, 1.0, 0.0, 0.15, 0.05 };
const float w_init_Independence[5] = { 0.1, 1.0, 0.0, 0.15, 0.05 };



//
// UTILITY FUNCTIONS
//

// Uniform pseudo-random number between low and high, drawn from rand().
// The unit value rand()/RAND_MAX lies in the CLOSED interval [0, 1]: when
// rand() == RAND_MAX it is exactly 1, so e.g. Random(0, 2) can return 2.0.
// Callers that truncate the result to an index must account for that.
float
Random(float low, float high)
{
	return low + (high - low) * ((float)rand() / (float)RAND_MAX);
}

//
// EXPERIMENT SET-UP
//

// One phase of an experiment.
//
// `stimuli` encodes the two alternatives offered, separated by '/', using
// the letters Y, B, G, H, L and '+' for reward (see Phase_SetUp), e.g. "B+/Y".
// A NULL `stimuli` terminates an experiment's list of phases.
//
// The table is initialized from string literals, so the pointer must be
// `const char *`: binding a string literal to `char *` is ill-formed since C++11.
typedef struct {
	int		choices;	// number of choices the simulated animal makes in this phase
	const char *	stimuli;	// encoded alternatives, or NULL as terminator
	float		cS1;		// observed % choosing the first alternative (compared by Results_List)
	float		cS2;		// observed % choosing the second alternative
	float		n;		// not read by the simulation; presumably the number of real animals tested — TODO confirm
} Experiment;


const int no_of_experiments = 20;
const int no_of_phases = 10;

// The tables below define the different experiments and the experimental results on real animals
// used to calculate the performance of the different models.
//
// Each inner brace is one phase: { choices, stimuli, cS1, cS2, n }.
// Training phases have no observed data (cS1 = cS2 = 0); the final test phase of each
// experiment carries the observed choice percentages. A { 0, NULL, ... } row ends the list,
// and any remaining phases are zero-initialized.

Experiment
experiment[no_of_experiments][no_of_phases] =
{
	// Preference Tests (Experiment 1-5)

	{
		{ 1,	"B/Y",		88,	12,	25 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 1,	"YH/YL",	63,	37,	38 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 1,	"BH/BL",	52,	48,	21 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 1,	"BL/YH",	96,	4,	25 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 1,	"BH/YL",	90,	10,	10 },
		{ 0,	NULL,		0,	0,	0 }
	},

	// Simple Learning (Experiment 6-7)

	{
		{ 50,	"B+/Y",		0,	0,	0 },
		{ 1,	"B/Y",		95,	5,	20 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"B/Y+",		0,	0,	0 },
		{ 1,	"B/Y",		20,	80,	20 },
		{ 0,	NULL,		0,	0,	0 }
	},

	// Odour Learning with Same Colour (Experiment 8-11)

	{
		{ 50,	"BH+/BL",	0,	0,	0 },
		{ 1,	"BH/BL",	50,	50,	14 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"BH/BL+",	0,	0,	0 },
		{ 1,	"BH/BL",	58,	42,	24 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"YH+/YL",	0,	0,	0 },
		{ 1,	"YH/YL",	81,	19,	42 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"YH/YL+",	0,	0,	0 },
		{ 1,	"YH/YL",	29,	71,	28 },
		{ 0,	NULL,		0,	0,	0 }
	},

	// Combinations (Experiment 12-16)

	{
		{ 50,	"YH+/BL",	0,	0,	0 },
		{ 1,	"YL/BH",	62,	38,	50 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"YH+/BL",	0,	0,	0 },
		{ 1,	"Y/B",		100,	0,	18 },
		{ 0,	NULL,		0,	0,	0 }
	},

	{
		{ 50,	"BL+/YH",	0,	0,	0 },
		{ 1,	"YL/BH",	0,	100,	21 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"BL+/YH",	0,	0,	0 },
		{ 1,	"Y/B",		0,	100,	10 },
		{ 0,	NULL,		0,	0,	0 }
	},

	{
		{ 50,	"YL+/BH",	0,	0,	0 },
		{ 1,	"YH/BL",	72,	28,	18 },
		{ 0,	NULL,		0,	0,	0 }
	},

	// Blocking Like (Experiment 17-20)

	{
		{ 50,	"B/Y+",		0,	0,	0 },
		{ 50,	"BH/BL+",	0,	0,	0 },
		{ 1,	"BH/BL",	21,	79,	34 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"B/Y+",		0,	0,	0 },
		{ 50,	"BH+/BL",	0,	0,	0 },
		{ 1,	"BH/BL",	79,	21,	19 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"Y+",		0,	0,	0 },
		{ 50,	"YH/YL+",	0,	0,	0 },
		{ 1,	"YH/YL",	58,	42,	26 },
		{ 0,	NULL,		0,	0,	0 }
	},
	{
		{ 50,	"Y+",		0,	0,	0 },
		{ 50,	"YH+/YL",	0,	0,	0 },
		{ 1,	"YH/YL",	50,	50,	14 },
		{ 0,	NULL,		0,	0,	0 }
	}
};



bool
Phase_SetUp(int e, int p, float stimulus[2][6], int & choices)
{
	if(experiment[e][p].stimuli == NULL)
		return false;

	// Get stimuli and iterations
	
	for(int i=0; i<6; i++)
	{
		stimulus[0][i] = 0;
		stimulus[1][i] = 0;
	}

	int s = 0;
	for(int i=0; i<strlen(experiment[e][p].stimuli); i++)
		switch(experiment[e][p].stimuli[i])
		{
			case 'Y':		stimulus[s][0] = 1; break;
			case 'B':		stimulus[s][1] = 1; break;
			case 'G':		stimulus[s][2] = 1; break;
			case 'H':		stimulus[s][3] = 1; break;
			case 'L':		stimulus[s][4] = 1; break;
			case '+':		stimulus[s][5] = 1; break;
			case '/':		s++;
			default:		break;
		};

	choices = experiment[e][p].choices;

	return true;
}



// Choice counters summed over all simulated animals:
// results[experiment][phase][0] / [1] = how often the first / second alternative was chosen.
long results[no_of_experiments][no_of_phases][2] = {};


// Zero every choice counter in results[][][] before a new batch of runs.
void
Results_Reset()
{
	for(int exp_index = 0; exp_index < no_of_experiments; exp_index++)
		for(int phase = 0; phase < no_of_phases; phase++)
			for(int side = 0; side < 2; side++)
				results[exp_index][phase][side] = 0;
}

// Square of x.
float
sqr(float x)
{
	const float squared = x * x;
	return squared;
}


// Print the simulated choice counts and percentages for every phase, next to
// the observed percentages where the table has them (cS1 + cS2 != 0). Then
// print the average and maximum absolute error of the simulated S1 percentage
// against the observed cS1, over all phases that have observed data and at
// least one simulated choice.
void
Results_List()
{
	float stim[2][6];
	int choices;

	printf("\nResults (Choices)\n\n");
	printf("Exp  Phase     N    Stimuli           S1        S2        S1(%%)/S2(%%)\n");

	for(int e=0; e<no_of_experiments; e++)
	{
		printf("--------------------------------------------------------------------------\n");
		for(int p=0; Phase_SetUp(e, p, stim, choices); p++)
		{
			// results[] holds long counters, so they must be printed with %ld
			if(p==0)
				printf("%2d   %d%10d    %-10s%10ld%10ld", e+1, p+1, choices, experiment[e][p].stimuli, results[e][p][0], results[e][p][1]);
			else
				printf("     %d%10d    %-10s%10ld%10ld", p+1, choices, experiment[e][p].stimuli, results[e][p][0], results[e][p][1]);

			float total = results[e][p][0]+results[e][p][1];
			if(total > 0 && experiment[e][p].cS1 + experiment[e][p].cS2 == 0)
			{
				printf("        %.0f/%.0f\n", 100*results[e][p][0]/total, 100*results[e][p][1]/total);
			}
			else if(total > 0)
			{
				printf("        %.0f/%.0f\t\t%.0f/%.0f\n", 100*results[e][p][0]/total, 100*results[e][p][1]/total, experiment[e][p].cS1, experiment[e][p].cS2);
			}
			else
				printf("        -/-\n");
		}
	}
	printf("--------------------------------------------------------------------------\n");
	printf("n = %d, e = %.0f, ", number_of_animals, exponent);

	// Calculate total error. The average is taken over the phases actually
	// compared, not a hard-coded count: the table currently has 20 such phases.

	float error = 0;
	float max_error = 0;
	int compared = 0;
	for(int e=0; e<no_of_experiments; e++)
		for(int p=0; Phase_SetUp(e, p, stim, choices); p++)
		{
			float total = results[e][p][0]+results[e][p][1];
			if(total > 0 && experiment[e][p].cS1 + experiment[e][p].cS2 != 0)
			{
				// Percentages sum to 100, so the S1 error equals the S2 error
				float err = fabsf(100*results[e][p][0]/total - experiment[e][p].cS1);
				error += err;
				compared++;
				if(err > max_error)
					max_error = err;
			}
		}

	float average_error = compared > 0 ? error/compared : 0;
	printf("average error = %.2f%%, max error = %.2f%%\n", average_error, max_error);

	printf("model = %d\n", model);
}



//
// SIMULATE MOTH
//


// Current associative weights of the simulated animal, one per stimulus
// feature (Y, B, G, H, L); reset for every animal by the Moth_Reset_* functions.
float w[5] = { 0, 0, 0, 0, 0 };



// Load the categorization model's initial weights into w and scale them to
// sum to 1 (skipped if they sum to 0).
void
Moth_Reset_Categorization()
{
	float total = 0;
	for(int k = 0; k < 5; k++)
	{
		w[k] = w_init_Categorization[k];
		total += w[k];
	}

	if(total == 0)
		return;

	for(int k = 0; k < 5; k++)
		w[k] /= total;
}



// Load the Rescorla-Wagner model's initial weights into w (no normalization).
void
Moth_Reset_RescorlaWagner()
{
	memcpy(w, w_init_RescorlaWagner, sizeof w);
}



// Load the independence model's initial weights into w (no normalization).
void
Moth_Reset_Independence()
{
	memcpy(w, w_init_Independence, sizeof w);
}



// Pick one of the two alternatives (0 or 1) with equal probability.
// Random() returns a value in the CLOSED interval [low, high], so the former
// int(Random(0, 2)) produced 2 whenever rand() == RAND_MAX. That read
// stimuli[2] out of bounds here, and the returned 2 made the caller write
// out of bounds in results[e][p][2]. Comparing against 0.5 always yields 0 or 1.
static int
Moth_RandomAlternative()
{
	return Random(0, 1) < 0.5f ? 0 : 1;
}

// Simulate one choice between the two alternatives in stimuli[0] and stimuli[1].
// On each of up to 100 approaches, an alternative is picked at random. Its value
// V (the sum of w[i]*feature[i] over the five features) is computed, and the
// alternative is accepted if pow(V, exponent) exceeds a uniform random draw
// from [0, 1]. If every approach is rejected, a random alternative is returned.
// The result is always 0 or 1.
int
Moth_Select(float stimuli[2][6])
{
	for(int attempt=0; attempt<100; attempt++)
	{
		int choice = Moth_RandomAlternative();

		float V = 0;
		for(int i=0; i<5; i++)
			V += w[i]*stimuli[choice][i];

		if(pow(V, exponent) > Random(0, 1))
			return choice;
	}

	return Moth_RandomAlternative();
}



// Categorization-model update for the chosen alternative's feature vector.
// Only rewarded visits (stimulus[5] != 0) change the weights:
//   - present features gain LR_Categorization * max(0, 1 - expectation);
//   - absent features lose LR2_Categorization;
//   - weights are clipped at 0 and then scaled to sum to 1 (unless all are 0).
void
Moth_Learn_Categorization(float * stimulus)
{
	const float reward = stimulus[5];
	if(reward == 0)
		return;		// unrewarded visits leave w unchanged

	// Expected reward from the five features of this alternative
	float expectation = 0;
	for(int k=0; k<5; k++)
		expectation += w[k]*stimulus[k];

	// Only an under-expectation strengthens the present features
	float shortfall = 1 - expectation;
	if(shortfall < 0)
		shortfall = 0;

	// Update, clip and accumulate the new total in one pass
	float total = 0;
	for(int k=0; k<5; k++)
	{
		if(stimulus[k] > 0)
			w[k] += LR_Categorization * shortfall;
		else
			w[k] -= LR2_Categorization;

		if(w[k] < 0)
			w[k] = 0;

		total += w[k];
	}

	if(total != 0)
		for(int k=0; k<5; k++)
			w[k] /= total;
}



// Rescorla-Wagner update for the chosen alternative's feature vector.
// The prediction error (reward flag stimulus[5] minus the weighted sum of the
// five features) is applied to every feature in proportion to its presence.
void
Moth_Learn_RescorlaWagner(float * stimulus)
{
	float prediction = 0;
	for(int k=0; k<5; k++)
		prediction += w[k]*stimulus[k];

	const float surprise = stimulus[5] - prediction;

	for(int k=0; k<5; k++)
		w[k] += LR_RescorlaWagner * surprise * stimulus[k];
}



// Independence-model update for the chosen alternative's feature vector.
// Each present feature's weight moves by LR_Independence * (2*reward - 1),
// which is +LR on rewarded visits and -LR on unrewarded ones for a 0/1 reward
// flag. The weights are not bounded.
void
Moth_Learn_Independence(float * stimulus)
{
	const float sign = 2*stimulus[5] - 1;

	const float * feature = stimulus;
	for(float * weight = w; weight != w + 5; ++weight, ++feature)
		*weight += LR_Independence * sign * *feature;
}



//
// RUN EXPERIMENTS
//

// Print one encoded feature vector as six space-separated values.
static void
Experiment_PrintStimulus(const float features[6])
{
	for(int k=0; k<6; k++)
		printf("%.0f ", features[k]);
}

// Diagnostic listing: for every experiment (0-based) and each of its phases,
// print the phase index, the number of choices and both encoded alternatives.
void
Experiment_List()
{
	float stimuli[2][6];
	int choices;

	for(int e=0; e<no_of_experiments; e++)
	{
		printf("\nExperiment %d\n", e);

		int p = 0;
		while(Phase_SetUp(e, p, stimuli, choices))
		{
			printf("\t%d: %5d\t\t", p, choices);
			Experiment_PrintStimulus(stimuli[0]);
			printf("\t-\t");
			Experiment_PrintStimulus(stimuli[1]);
			printf("\n");
			p++;
		}
	}
}



// Simulate n animals through every phase of experiment e.
// Each animal starts from the selected model's initial weights. In every
// phase it makes `choices` choices, learns from the chosen alternative with
// that model's rule, and the choice is counted in results[e][p][choice].
// Side effect: for model 0 the global response exponent is set to 2.
void
Experiment_Run(int e, int n)
{
	int choices = 0;
	float stimuli[2][6];

	if(model == 0)
		exponent = 2;

	for(int animal=0; animal < n; animal++)
	{
		if(model == 0)
			Moth_Reset_Categorization();
		else if(model == 1)
			Moth_Reset_RescorlaWagner();
		else
			Moth_Reset_Independence();

		for(int p=0; Phase_SetUp(e, p, stimuli, choices); p++)
		{
			for(int c=0; c < choices; c++)
			{
				// Random baseline choice. rand() replaces the POSIX-only random(),
				// which standard C++ <cstdlib> does not declare. It is drawn on
				// every trial so each trial consumes the same random numbers
				// regardless of model.
				int choice = rand() % 2;

				// Model 3 keeps the random choice (no Moth_Select) and learns
				// with the Independence rule via the final else below.
				if(model != 3)
					choice = Moth_Select(stimuli);

				// results[e][p] has exactly two counters; never index past them
				if(choice < 0 || choice > 1)
					continue;

				if(model == 0)
					Moth_Learn_Categorization(stimuli[choice]);
				else if(model == 1)
					Moth_Learn_RescorlaWagner(stimuli[choice]);
				else
					Moth_Learn_Independence(stimuli[choice]);

				results[e][p][choice]++;
			}
		}
	}
}



// Clear the counters, simulate number_of_animals animals in each of the
// experiments, then print the comparison with the observed data.
int
main()
{
	Results_Reset();

	int e = 0;
	while(e < no_of_experiments)
	{
		Experiment_Run(e, number_of_animals);
		++e;
	}

	Results_List();
	return 0;
}

