Anyone have any experience training linear layers on top of embeddings?


I am trying to build an open-domain Q&A bot on a private dataset. Right now I use the embeddings endpoint to embed every question and document. When a new question comes in, I embed it w/ the API, compute the cosine similarity between that question and every document, and then feed the most similar documents into GPT-3's context window so it can write an answer.
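
For anyone following along, the retrieval step described above looks roughly like this. It's a minimal NumPy sketch, assuming the question and document embeddings are already in arrays; `top_k_documents`, `build_prompt`, and the prompt wording are hypothetical names/text for illustration, not anything from the OpenAI API.

```python
import numpy as np


def top_k_documents(question_emb, doc_embs, k=3):
    """Indices and cosine scores of the k documents most similar to the question."""
    q = np.asarray(question_emb, dtype=float)
    D = np.asarray(doc_embs, dtype=float)
    q = q / np.linalg.norm(q)                         # unit-normalize the question
    D = D / np.linalg.norm(D, axis=1, keepdims=True)  # unit-normalize each document row
    scores = D @ q                                    # cosine similarity per document
    top = np.argsort(-scores)[:k]                     # highest scores first
    return top, scores[top]


def build_prompt(question, documents, doc_ids):
    """Hypothetical prompt template: retrieved documents as context, then the question."""
    context = "\n\n".join(documents[i] for i in doc_ids)
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
```

Normalizing inside the function is harmless if the vectors are already unit length, and it keeps the function correct if you later swap in a projected embedding that isn't.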

This works OK but clearly fails in many cases, e.g. it fails to retrieve the relevant documents for a given question.

I have a bunch of labeled examples (between 1k and 10k). I’d like to train a very simple layer on top of these embeddings w/ the aim of improving relevance.

Right now I’m using a triplet loss and a linear layer, but I can’t seem to do any better than the original embeddings. My guess is that the failure point is the size of the training set relative to the number of parameters in the linear layer: a linear layer mapping all 1,536 dimensions to another 1,536 has about 2.36M weights (1536 × 1536). I’m running some experiments now where I shrink the output of the linear layer (e.g. so that it maps to a 128- or 64-dimensional embedding rather than 1,536), which brings it down to roughly 197k or 98k weights.
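
To make the setup concrete, here's a minimal NumPy sketch of training a linear projection W (1,536 → k) with a triplet loss by plain gradient descent. Assumptions on my part: squared Euclidean distances, margin 1, full-batch updates, and random-normal initialization; the function names are hypothetical, and this is not how any particular framework implements it.

```python
import numpy as np


def param_count(d_in, d_out, bias=False):
    """Number of trainable parameters in a linear layer mapping d_in -> d_out."""
    return d_in * d_out + (d_out if bias else 0)


def triplet_loss_and_grad(W, A, P, N, margin=1.0):
    """Mean triplet hinge loss on squared Euclidean distances after projecting with W.

    A, P, N: (n, d) anchor (question), positive and negative (document) embeddings.
    W: (d, k) projection matrix. Returns (loss, gradient of the loss w.r.t. W).
    """
    dP = A - P                         # anchor-positive differences, (n, d)
    dN = A - N                         # anchor-negative differences, (n, d)
    U, V = dP @ W, dN @ W              # projected differences, (n, k)
    hinge = np.sum(U**2, axis=1) - np.sum(V**2, axis=1) + margin
    active = hinge > 0                 # only margin-violating triplets get a gradient
    n = A.shape[0]
    loss = hinge[active].sum() / n
    # d/dW ||x W||^2 = 2 x^T (x W), summed over the active triplets
    grad = 2.0 * (dP[active].T @ U[active] - dN[active].T @ V[active]) / n
    return loss, grad


def train_projection(A, P, N, k, lr=0.01, steps=300, seed=0):
    """Full-batch gradient descent on W; returns W and the per-step loss history."""
    rng = np.random.default_rng(seed)
    d = A.shape[1]
    W = rng.normal(scale=1.0 / np.sqrt(d), size=(d, k))
    history = []
    for _ in range(steps):
        loss, grad = triplet_loss_and_grad(W, A, P, N)
        history.append(loss)
        W -= lr * grad
    return W, history
```

For the sizes above, `param_count(1536, 1536)` is 2,359,296, `param_count(1536, 128)` is 196,608 and `param_count(1536, 64)` is 98,304. One caveat with unnormalized squared distances: scaling W up can widen every gap relative to a fixed margin, so L2-normalizing the projected vectors is a common tweak worth trying.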

That said, I can’t do any better than identity: the triplet loss with the identity mapping is ≈ 1, and it’s the same with the trained linear layer.
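
A quick way to see why a loss stuck at ≈1 is a red flag: assuming a margin of 1 (the default in, e.g., PyTorch's `TripletMarginLoss`), the hinge max(0, d(a,p) − d(a,n) + 1) equals exactly 1 whenever d(a,p) = d(a,n). If a "negative" happens to be identical to the positive, that holds for every projection W, so the gradient w.r.t. W is zero and no linear layer can beat identity. A small NumPy check, with random data standing in for real embeddings:

```python
import numpy as np


def triplet_hinge(a, p, n, margin=1.0):
    """Per-triplet loss max(0, ||a - p|| - ||a - n|| + margin) with Euclidean distances."""
    d_pos = np.linalg.norm(a - p, axis=-1)
    d_neg = np.linalg.norm(a - n, axis=-1)
    return np.maximum(0.0, d_pos - d_neg + margin)


rng = np.random.default_rng(0)
a = rng.normal(size=(4, 16))   # anchors (questions)
p = rng.normal(size=(4, 16))   # positives (documents)
n = p.copy()                   # "negatives" that are accidentally the positives
W = rng.normal(size=(16, 8))   # any linear projection

loss_identity = triplet_hinge(a, p, n)
loss_projected = triplet_hinge(a @ W, p @ W, n @ W)
print(loss_identity)   # -> [1. 1. 1. 1.]
print(loss_projected)  # also 1.0 per triplet, up to floating-point rounding
```

The same symptom can show up less dramatically when the positive and negative distances are nearly equal or tiny compared with the margin, so it's worth printing d(a,p) and d(a,n) separately.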

Does anyone have experience training a linear layer on top of OpenAI’s embeddings for this use case?

I was considering moving to a different service, such as a Hugging Face model, so that I can fine-tune the entire network on this smaller dataset, but I’m not sure that would be a good use of time.
Another alternative is to get much more training data from a less relevant set, embed that dataset and its questions, train the linear layer on it, and then transfer it to my smaller dataset.
I also realize I could be doing something wrong in the training process (it is quite suspicious that the triplet loss for a linear layer never gets below a training loss of 1… you’d think it would get somewhere…).

Anyway, this felt like a common use case, so I wanted to ask in case folks had experience.

Eeep, it was indeed my code. I had a bug that meant a hard negative sample could be the positive document itself.
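
For anyone hitting the same bug: when mining hard negatives, the positive has to be masked out before taking the most similar document. A minimal NumPy sketch, assuming one positive document per question and unit-norm rows (so the dot product is cosine similarity); `mine_hard_negatives` is a hypothetical helper name.

```python
import numpy as np


def mine_hard_negatives(q_embs, doc_embs, positive_ids):
    """For each question i, the index of the most similar document that is NOT its positive."""
    sims = np.asarray(q_embs, dtype=float) @ np.asarray(doc_embs, dtype=float).T
    rows = np.arange(sims.shape[0])
    # Mask out each question's own positive so it can never be picked as its negative.
    sims[rows, np.asarray(positive_ids)] = -np.inf
    return np.argmax(sims, axis=1)
```

If a question can have several relevant documents, the same idea applies with a boolean mask over all of them; otherwise a second relevant document can still sneak in as a "negative".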

Metrics look more sensible now: train and validation losses are both decreasing, etc.
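
Beyond the losses, one common way to check that the projection actually improves retrieval is recall@k on held-out questions, computed once on the raw embeddings and once after applying W. A minimal NumPy sketch, assuming one positive document per question; the function name is hypothetical.

```python
import numpy as np


def recall_at_k(q_embs, doc_embs, positive_ids, k=5):
    """Fraction of questions whose positive document is in the top-k by dot-product similarity."""
    sims = np.asarray(q_embs, dtype=float) @ np.asarray(doc_embs, dtype=float).T
    top_k = np.argsort(-sims, axis=1)[:, :k]                   # best k documents per question
    hits = (top_k == np.asarray(positive_ids)[:, None]).any(axis=1)
    return float(hits.mean())
```

Comparing `recall_at_k(Q, D, ids)` with `recall_at_k(Q @ W, D @ W, ids)` gives a direct "better than identity?" answer in terms of what the Q&A bot actually needs.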

Will report back if there are interesting findings.

This is definitely an area of interest for me (mapping embeddings to other embeddings with a simple NN).

Another approach would be to map correlations to other correlations: you correlate at the sentence, paragraph, and page levels, with each correlation level having context on the embeddings surrounding it. You could create some very rich prompts out of this without a neural network (NN). Here is a screenshot of my thoughts on this. It’s still pretty rough, but it would be a good alternative to NNs. I’m still interested in NNs too, though.