using System;
using System.Collections.Generic;

/// <summary>
/// Maps a fixed list of tag names to one-hot vectors and back.
/// </summary>
/// <remarks>
/// After <see cref="Initialize"/>, row <c>i</c> of <see cref="onehot"/> has a 1 at
/// position <c>i</c> and 0 everywhere else, and corresponds to the tag at index
/// <c>i</c> of the list that was passed in.
/// </remarks>
public class Onehot
{
    private List<string> tags = new List<string>();

    /// <summary>The one-hot table; one row per tag, in tag order.</summary>
    public List<List<float>> onehot = new List<List<float>>();

    // Number of tags, which is also the length of every vector produced.
    private int totalNum;

    /// <summary>
    /// Builds the one-hot table for the given tags, replacing any previous table.
    /// </summary>
    /// <param name="inputTags">Tag names; the position of a tag is the position of its 1.</param>
    /// <exception cref="ArgumentNullException"><paramref name="inputTags"/> is null.</exception>
    public void Initialize(List<string> inputTags)
    {
        if (inputTags == null)
        {
            throw new ArgumentNullException(nameof(inputTags));
        }

        // Copy the tags so later changes to the caller's list cannot
        // desynchronise the tag list from the table built below.
        tags = new List<string>(inputTags);
        totalNum = tags.Count;

        // Drop rows left over from an earlier Initialize call; without this the
        // new rows would be appended after the old ones and no longer line up with tags.
        onehot.Clear();

        for (int i = 0; i < totalNum; i++)
        {
            List<float> row = CreateZeroVector();
            row[i] = 1f;
            onehot.Add(row);
        }
    }

    /// <summary>
    /// Returns the one-hot vector for a tag.
    /// </summary>
    /// <param name="name">Tag to encode; may be null.</param>
    /// <returns>
    /// The stored table row for <paramref name="name"/> (the same list instance that is
    /// held in <see cref="onehot"/>), or a newly allocated all-zero vector when
    /// <paramref name="name"/> is null or is not a known tag.
    /// </returns>
    public List<float> Encoder(string name = null)
    {
        if (name == null)
        {
            return CreateZeroVector();
        }

        int index = tags.IndexOf(name);

        // IndexOf yields -1 for an unknown tag. The upper bound is checked as well
        // because onehot is a public field and may have been modified externally.
        if (index < 0 || index >= onehot.Count)
        {
            return CreateZeroVector();
        }

        return onehot[index];
    }

    /// <summary>
    /// Returns the tag whose one-hot row matches the given vector.
    /// </summary>
    /// <param name="thisOnehot">Vector to decode.</param>
    /// <returns>The tag at the index of the matching row.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// No row matches <paramref name="thisOnehot"/> (this includes null and all-zero vectors).
    /// </exception>
    public string Decoder(List<float> thisOnehot)
    {
        // A row previously handed out by Encoder is found by identity first.
        int index = onehot.IndexOf(thisOnehot);

        // List.IndexOf compares List<float> rows by reference, so a vector with the
        // same contents but a different instance needs an element-wise comparison.
        if (index < 0 && thisOnehot != null)
        {
            for (int i = 0; i < onehot.Count; i++)
            {
                if (RowsEqual(onehot[i], thisOnehot))
                {
                    index = i;
                    break;
                }
            }
        }

        if (index < 0)
        {
            throw new ArgumentOutOfRangeException(
                nameof(thisOnehot), "The vector does not match any known tag.");
        }

        return tags[index];
    }

    // Creates a new vector of totalNum zeros.
    private List<float> CreateZeroVector()
    {
        List<float> zeros = new List<float>(totalNum);
        for (int j = 0; j < totalNum; j++)
        {
            zeros.Add(0f);
        }
        return zeros;
    }

    // True when both vectors are non-null, have the same length and identical elements.
    private static bool RowsEqual(List<float> a, List<float> b)
    {
        if (a == null || b == null || a.Count != b.Count)
        {
            return false;
        }

        for (int i = 0; i < a.Count; i++)
        {
            if (a[i] != b[i])
            {
                return false;
            }
        }
        return true;
    }
}