Classify text with embeddings

Overview

In this notebook, you’ll learn to use the embeddings produced by the Gemini API to train a model that classifies newsgroup posts by topic.

Setup

Install the Google GenAI SDK

Install the Google GenAI SDK from npm.

$ npm install @google/genai

Set up your API key

You can create your API key using Google AI Studio with a single click.

Remember to treat your API key like a password. Don’t accidentally save it in a notebook or source file you later commit to GitHub. In this notebook we will be storing the API key in a .env file. You can also set it as an environment variable or use a secret manager.

Here’s how to set it up in a .env file:

$ touch .env
$ echo "GEMINI_API_KEY=<YOUR_API_KEY>" >> .env
Tip

Another option is to set the API key as an environment variable. You can do this in your terminal with the following command:

$ export GEMINI_API_KEY="<YOUR_API_KEY>"

Load the API key

To load the API key from the .env file, we will use the dotenv package. This package loads environment variables from a .env file into process.env.

$ npm install dotenv

Then, we can load the API key in our code:

const dotenv = require("dotenv") as typeof import("dotenv");

dotenv.config({
  path: "../.env",
});

const GEMINI_API_KEY = process.env.GEMINI_API_KEY ?? "";
if (!GEMINI_API_KEY) {
  throw new Error("GEMINI_API_KEY is not set in the environment variables");
}
console.log("GEMINI_API_KEY is set in the environment variables");
GEMINI_API_KEY is set in the environment variables
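For intuition about what `dotenv.config()` does: it reads `KEY=VALUE` lines from the file and copies them into `process.env`. By default it leaves any variable that is already set untouched. The sketch below is a deliberately simplified parser. `parseDotEnv` and `applyEnv` are hypothetical, illustrative names, not dotenv APIs. The sketch ignores details such as multiline values, inline comments, and escape sequences, so keep using the real package in the notebook.

```typescript
// Hypothetical, simplified .env parser for illustration only (not the dotenv implementation).
// Handles KEY=VALUE lines, blank lines, "#" comment lines, and one pair of surrounding quotes.
function parseDotEnv(source: string): Record<string, string> {
  const result: Record<string, string> = {};
  for (const rawLine of source.split(/\r?\n/)) {
    const line = rawLine.trim();
    // Skip blank lines and comment lines
    if (line === "" || line.startsWith("#")) continue;
    const eq = line.indexOf("=");
    if (eq === -1) continue;
    const key = line.slice(0, eq).trim();
    let value = line.slice(eq + 1).trim();
    // Strip one pair of matching surrounding quotes
    if (
      value.length >= 2 &&
      (value[0] === '"' || value[0] === "'") &&
      value[value.length - 1] === value[0]
    ) {
      value = value.slice(1, -1);
    }
    result[key] = value;
  }
  return result;
}

// Copy parsed values into an env object without overwriting keys that are already set,
// mirroring dotenv's default (non-override) behavior.
function applyEnv(parsed: Record<string, string>, env: Record<string, string | undefined>): void {
  for (const [key, value] of Object.entries(parsed)) {
    if (env[key] === undefined) env[key] = value;
  }
}

const parsed = parseDotEnv('# comment\nGEMINI_API_KEY="abc123"\nOTHER=x\n');
const env: Record<string, string | undefined> = { OTHER: "already-set" };
applyEnv(parsed, env);
console.log(env); // { OTHER: 'already-set', GEMINI_API_KEY: 'abc123' }
```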
Note

In this particular case, the .env file is one directory up from the notebook, hence the ../ in the path. If the .env file is in the same directory as the notebook, you can omit the path option altogether, since dotenv looks for .env in the current working directory by default.

│
├── .env
└── examples
    └── Classify_text_with_embeddings.ipynb

Initialize SDK Client

With the new SDK, you only need to initialize a client with your API key (or with OAuth if you use Vertex AI). The model is specified in each call instead.

const google = require("@google/genai") as typeof import("@google/genai");

const ai = new google.GoogleGenAI({ apiKey: GEMINI_API_KEY });

Select a model

Now set the model you want to use in this guide by assigning its ID to MODEL_ID. Keep in mind that some models, like the 2.5 ones, are thinking models and therefore take slightly more time to respond (see the thinking notebook for more details, in particular how to switch thinking off).

const tslab = require("tslab") as typeof import("tslab");

const MODEL_ID = "gemini-2.5-flash-preview-05-20";

Prepare dataset

The 20 Newsgroups Text Dataset contains about 18,000 newsgroup posts on 20 topics, divided into training and test sets. The split between the training and test sets is based on messages posted before and after a specific date. For this tutorial, you will use subsets of the training and test sets. You will preprocess and organize the data into danfo.js DataFrames.

const fs = require("fs") as typeof import("fs");
const path = require("path") as typeof import("path");
const tar = require("tar") as typeof import("tar");
const danfo = require("danfojs-node") as typeof import("danfojs-node");

// URL of the scikit-learn 20 Newsgroups dataset
const DATA_URL = "https://ndownloader.figshare.com/files/5975967";
const EXTRACT_PATH = "../assets/anomaly_detection";

async function downloadAndExtractDataset(): Promise<void> {
  if (fs.existsSync(EXTRACT_PATH)) {
    console.log("Dataset already exists. Skipping download.");
    return;
  }

  console.log("Downloading 20 Newsgroups dataset...");
  const response = await fetch(DATA_URL);
  if (!response.ok) {
    throw new Error(`Download failed: ${response.status} ${response.statusText}`);
  }
  const buffer = await response.arrayBuffer();

  console.log("Extracting dataset...");
  await fs.promises.mkdir(EXTRACT_PATH, { recursive: true });

  // The download is a gzipped tarball, not a zip archive
  const archivePath = path.join(EXTRACT_PATH, "20news-bydate.tar.gz");
  fs.writeFileSync(archivePath, Buffer.from(buffer));

  await tar.x({
    file: archivePath,
    cwd: EXTRACT_PATH,
  });

  console.log("Dataset extracted.");
}

function loadTextFilesFromDir(dirPath: string): {
  data: string[];
  target: string[];
} {
  const categories = fs.readdirSync(dirPath);
  const data: string[] = [];
  const target: string[] = [];

  for (const category of categories) {
    const categoryPath = path.join(dirPath, category);
    if (fs.lstatSync(categoryPath).isDirectory()) {
      const files = fs.readdirSync(categoryPath);
      for (const file of files) {
        const filePath = path.join(categoryPath, file);
        const content = fs.readFileSync(filePath, "utf-8");
        data.push(content);
        target.push(category);
      }
    }
  }

  return { data, target };
}

await downloadAndExtractDataset();

const trainDir = path.join(EXTRACT_PATH, "20news-bydate-train");
const { data: trainData, target: trainTarget } = loadTextFilesFromDir(trainDir);
const trainDf = new danfo.DataFrame({
  data: trainData,
  target: trainTarget,
});

const testDir = path.join(EXTRACT_PATH, "20news-bydate-test");
const { data: testData, target: testTarget } = loadTextFilesFromDir(testDir);
const testDf = new danfo.DataFrame({
  data: testData,
  target: testTarget,
});
Dataset already exists. Skipping download.
/* eslint-disable @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call */
const classNames = trainDf.target.unique().values as string[];
console.log("Class names:", classNames);
Class names: [
  'alt.atheism',
  'comp.graphics',
  'comp.os.ms-windows.misc',
  'comp.sys.ibm.pc.hardware',
  'comp.sys.mac.hardware',
  'comp.windows.x',
  'misc.forsale',
  'rec.autos',
  'rec.motorcycles',
  'rec.sport.baseball',
  'rec.sport.hockey',
  'sci.crypt',
  'sci.electronics',
  'sci.med',
  'sci.space',
  'soc.religion.christian',
  'talk.politics.guns',
  'talk.politics.mideast',
  'talk.politics.misc',
  'talk.religion.misc'
]

Here is an example of what a data point from the training set looks like.

/* eslint-disable @typescript-eslint/no-unsafe-member-access */
const firstDoc = trainDf.loc({ rows: [0], columns: ["data"] });
const firstText = firstDoc.data.values[0] as string;

const idx = firstText.indexOf("Lines");

if (idx !== -1) {
  console.log(firstText.slice(idx));
} else {
  console.log('"Lines" not found in the first document.');
}
Lines: 290

Archive-name: atheism/resources
Alt-atheism-archive-name: resources
Last-modified: 11 December 1992
Version: 1.0

                              Atheist Resources

                      Addresses of Atheist Organizations

                                     USA

FREEDOM FROM RELIGION FOUNDATION

Darwin fish bumper stickers and assorted other atheist paraphernalia are
available from the Freedom From Religion Foundation in the US.

Write to:  FFRF, P.O. Box 750, Madison, WI 53701.
Telephone: (608) 256-8900

EVOLUTION DESIGNS

Evolution Designs sell the "Darwin fish".  It's a fish symbol, like the ones
Christians stick on their cars, but with feet and the word "Darwin" written
inside.  The deluxe moulded 3D plastic fish is $4.95 postpaid in the US.

Write to:  Evolution Designs, 7119 Laurel Canyon #4, North Hollywood,
           CA 91605.

People in the San Francisco Bay area can get Darwin Fish from Lynn Gold --
try mailing <figmo@netcom.com>.  For net people who go to Lynn directly, the
price is $4.95 per fish.

AMERICAN ATHEIST PRESS

AAP publish various atheist books -- critiques of the Bible, lists of
Biblical contradictions, and so on.  One such book is:

"The Bible Handbook" by W.P. Ball and G.W. Foote.  American Atheist Press.
372 pp.  ISBN 0-910309-26-4, 2nd edition, 1986.  Bible contradictions,
absurdities, atrocities, immoralities... contains Ball, Foote: "The Bible
Contradicts Itself", AAP.  Based on the King James version of the Bible.

Write to:  American Atheist Press, P.O. Box 140195, Austin, TX 78714-0195.
      or:  7215 Cameron Road, Austin, TX 78752-2973.
Telephone: (512) 458-1244
Fax:       (512) 467-9525

PROMETHEUS BOOKS

Sell books including Haught's "Holy Horrors" (see below).

Write to:  700 East Amherst Street, Buffalo, New York 14215.
Telephone: (716) 837-2475.

An alternate address (which may be newer or older) is:
Prometheus Books, 59 Glenn Drive, Buffalo, NY 14228-2197.

AFRICAN-AMERICANS FOR HUMANISM

An organization promoting black secular humanism and uncovering the history of
black freethought.  They publish a quarterly newsletter, AAH EXAMINER.

Write to:  Norm R. Allen, Jr., African Americans for Humanism, P.O. Box 664,
           Buffalo, NY 14226.

                                United Kingdom

Rationalist Press Association          National Secular Society
88 Islington High Street               702 Holloway Road
London N1 8EW                          London N19 3NL
071 226 7251                           071 272 1266

British Humanist Association           South Place Ethical Society
14 Lamb's Conduit Passage              Conway Hall
London WC1R 4RH                        Red Lion Square
071 430 0908                           London WC1R 4RL
fax 071 430 1271                       071 831 7723

The National Secular Society publish "The Freethinker", a monthly magazine
founded in 1881.

                                   Germany

IBKA e.V.
Internationaler Bund der Konfessionslosen und Atheisten
Postfach 880, D-1000 Berlin 41. Germany.

IBKA publish a journal:
MIZ. (Materialien und Informationen zur Zeit. Politisches
Journal der Konfessionslosesn und Atheisten. Hrsg. IBKA e.V.)
MIZ-Vertrieb, Postfach 880, D-1000 Berlin 41. Germany.

For atheist books, write to:

IBDK, Internationaler B"ucherdienst der Konfessionslosen
Postfach 3005, D-3000 Hannover 1. Germany.
Telephone: 0511/211216


                               Books -- Fiction

THOMAS M. DISCH

"The Santa Claus Compromise"
Short story.  The ultimate proof that Santa exists.  All characters and 
events are fictitious.  Any similarity to living or dead gods -- uh, well...

WALTER M. MILLER, JR

"A Canticle for Leibowitz"
One gem in this post atomic doomsday novel is the monks who spent their lives
copying blueprints from "Saint Leibowitz", filling the sheets of paper with
ink and leaving white lines and letters.

EDGAR PANGBORN

"Davy"
Post atomic doomsday novel set in clerical states.  The church, for example,
forbids that anyone "produce, describe or use any substance containing...
atoms". 

PHILIP K. DICK

Philip K. Dick Dick wrote many philosophical and thought-provoking short 
stories and novels.  His stories are bizarre at times, but very approachable.
He wrote mainly SF, but he wrote about people, truth and religion rather than
technology.  Although he often believed that he had met some sort of God, he
remained sceptical.  Amongst his novels, the following are of some relevance:

"Galactic Pot-Healer"
A fallible alien deity summons a group of Earth craftsmen and women to a
remote planet to raise a giant cathedral from beneath the oceans.  When the
deity begins to demand faith from the earthers, pot-healer Joe Fernwright is
unable to comply.  A polished, ironic and amusing novel.

"A Maze of Death"
Noteworthy for its description of a technology-based religion.

"VALIS"
The schizophrenic hero searches for the hidden mysteries of Gnostic
Christianity after reality is fired into his brain by a pink laser beam of
unknown but possibly divine origin.  He is accompanied by his dogmatic and
dismissively atheist friend and assorted other odd characters.

"The Divine Invasion"
God invades Earth by making a young woman pregnant as she returns from
another star system.  Unfortunately she is terminally ill, and must be
assisted by a dead man whose brain is wired to 24-hour easy listening music.

MARGARET ATWOOD

"The Handmaid's Tale"
A story based on the premise that the US Congress is mysteriously
assassinated, and fundamentalists quickly take charge of the nation to set it
"right" again.  The book is the diary of a woman's life as she tries to live
under the new Christian theocracy.  Women's right to own property is revoked,
and their bank accounts are closed; sinful luxuries are outlawed, and the
radio is only used for readings from the Bible.  Crimes are punished
retroactively: doctors who performed legal abortions in the "old world" are
hunted down and hanged.  Atwood's writing style is difficult to get used to
at first, but the tale grows more and more chilling as it goes on.

VARIOUS AUTHORS

"The Bible"
This somewhat dull and rambling work has often been criticized.  However, it
is probably worth reading, if only so that you'll know what all the fuss is
about.  It exists in many different versions, so make sure you get the one
true version.

                             Books -- Non-fiction

PETER DE ROSA

"Vicars of Christ", Bantam Press, 1988
Although de Rosa seems to be Christian or even Catholic this is a very
enlighting history of papal immoralities, adulteries, fallacies etc.
(German translation: "Gottes erste Diener. Die dunkle Seite des Papsttums",
Droemer-Knaur, 1989)

MICHAEL MARTIN

"Atheism: A Philosophical Justification", Temple University Press,
 Philadelphia, USA.
A detailed and scholarly justification of atheism.  Contains an outstanding
appendix defining terminology and usage in this (necessarily) tendentious
area.  Argues both for "negative atheism" (i.e. the "non-belief in the
existence of god(s)") and also for "positive atheism" ("the belief in the
non-existence of god(s)").  Includes great refutations of the most
challenging arguments for god; particular attention is paid to refuting
contempory theists such as Platinga and Swinburne.
541 pages. ISBN 0-87722-642-3 (hardcover; paperback also available)

"The Case Against Christianity", Temple University Press
A comprehensive critique of Christianity, in which he considers
the best contemporary defences of Christianity and (ultimately)
demonstrates that they are unsupportable and/or incoherent.
273 pages. ISBN 0-87722-767-5

JAMES TURNER

"Without God, Without Creed", The Johns Hopkins University Press, Baltimore,
 MD, USA
Subtitled "The Origins of Unbelief in America".  Examines the way in which
unbelief (whether agnostic or atheistic)  became a mainstream alternative
world-view.  Focusses on the period 1770-1900, and while considering France
and Britain the emphasis is on American, and particularly New England
developments.  "Neither a religious history of secularization or atheism,
Without God, Without Creed is, rather, the intellectual history of the fate
of a single idea, the belief that God exists." 
316 pages. ISBN (hardcover) 0-8018-2494-X (paper) 0-8018-3407-4

GEORGE SELDES (Editor)

"The great thoughts", Ballantine Books, New York, USA
A "dictionary of quotations" of a different kind, concentrating on statements
and writings which, explicitly or implicitly, present the person's philosophy
and world-view.  Includes obscure (and often suppressed) opinions from many
people.  For some popular observations, traces the way in which various
people expressed and twisted the idea over the centuries.  Quite a number of
the quotations are derived from Cardiff's "What Great Men Think of Religion"
and Noyes' "Views of Religion".
490 pages. ISBN (paper) 0-345-29887-X.

RICHARD SWINBURNE

"The Existence of God (Revised Edition)", Clarendon Paperbacks, Oxford
This book is the second volume in a trilogy that began with "The Coherence of
Theism" (1977) and was concluded with "Faith and Reason" (1981).  In this
work, Swinburne attempts to construct a series of inductive arguments for the
existence of God.  His arguments, which are somewhat tendentious and rely
upon the imputation of late 20th century western Christian values and
aesthetics to a God which is supposedly as simple as can be conceived, were
decisively rejected in Mackie's "The Miracle of Theism".  In the revised
edition of "The Existence of God", Swinburne includes an Appendix in which he
makes a somewhat incoherent attempt to rebut Mackie.

J. L. MACKIE

"The Miracle of Theism", Oxford
This (posthumous) volume contains a comprehensive review of the principal
arguments for and against the existence of God.  It ranges from the classical
philosophical positions of Descartes, Anselm, Berkeley, Hume et al, through
the moral arguments of Newman, Kant and Sidgwick, to the recent restatements
of the classical theses by Plantinga and Swinburne.  It also addresses those
positions which push the concept of God beyond the realm of the rational,
such as those of Kierkegaard, Kung and Philips, as well as "replacements for
God" such as Lelie's axiarchism.  The book is a delight to read - less
formalistic and better written than Martin's works, and refreshingly direct
when compared with the hand-waving of Swinburne.

JAMES A. HAUGHT

"Holy Horrors: An Illustrated History of Religious Murder and Madness",
 Prometheus Books
Looks at religious persecution from ancient times to the present day -- and
not only by Christians.
Library of Congress Catalog Card Number 89-64079. 1990.

NORM R. ALLEN, JR.

"African American Humanism: an Anthology"
See the listing for African Americans for Humanism above.

GORDON STEIN

"An Anthology of Atheism and Rationalism", Prometheus Books
An anthology covering a wide range of subjects, including 'The Devil, Evil
and Morality' and 'The History of Freethought'.  Comprehensive bibliography.

EDMUND D. COHEN

"The Mind of The Bible-Believer", Prometheus Books
A study of why people become Christian fundamentalists, and what effect it
has on them.

                                Net Resources

There's a small mail-based archive server at mantis.co.uk which carries
archives of old alt.atheism.moderated articles and assorted other files.  For
more information, send mail to archive-server@mantis.co.uk saying

   help
   send atheism/index

and it will mail back a reply.


mathew

Now you will begin preprocessing the data for this tutorial. Remove sensitive information such as email addresses and the sender line (which usually carries the author's name), as well as redundant parts of the text like "From: " and "\nSubject: ". Each post is also truncated to 5,000 characters. Organize the information into a danfo.js DataFrame so it is more readable.

/* eslint-disable no-control-regex, @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-argument, @typescript-eslint/no-unsafe-assignment */

import { DataFrame } from "danfojs-node";

function preprocessText(text: string): string {
  let cleaned = text;

  // Remove email addresses
  cleaned = cleaned.replace(/[\w.-]+@[\w.-]+/g, "");

  // Remove the first line of the post: the "From:" header, which usually holds
  // the sender's name. Without the m flag, ^ only matches at the start of the text.
  cleaned = cleaned.replace(/^(.*?)(?=\n)/, "");

  // Remove "From: "
  cleaned = cleaned.replace(/From: /g, "");

  // Remove "\nSubject: "
  cleaned = cleaned.replace(/\nSubject: /g, "");

  // Remove control characters
  cleaned = cleaned.replace(/[\x00-\x1F\x7F]/g, " ");

  // Truncate to 5000 characters
  if (cleaned.length > 5000) {
    cleaned = cleaned.slice(0, 5000);
  }

  return cleaned;
}

function preprocessDataframe(df: DataFrame): DataFrame {
  const preprocessedData = df.data.values.map((d: string) => preprocessText(d));
  const preprocessedDf = new DataFrame({
    text: preprocessedData as string[],
    target: df.target.values,
  });
  /* eslint-disable @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-assignment */
  const texts = preprocessedDf.text.values as string[];
  const classNameToLabelMap: Record<string, number> = preprocessedDf.target
    .unique()
    .values.reduce((acc: Record<string, number>, className: string, index: number) => {
      acc[className] = index + 1; // Start labels from 1
      return acc;
    }, {});
  const classNames = preprocessedDf.target.values as string[];
  const labels = classNames.map((name) => classNameToLabelMap[name]);
  return new DataFrame({
    Text: texts,
    Label: labels,
    "Class Name": classNames,
  });
}

const trainDfPreprocessed = preprocessDataframe(trainDf);
const testDfPreprocessed = preprocessDataframe(testDf);
trainDfPreprocessed.head().print();
╔════════════╤═══════════════════╤═══════════════════╤═══════════════════╗
║            │ Text              │ Label             │ Class Name        ║
╟────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 0          │ Alt.Atheism FAQ…  │ 1                 │ alt.atheism       ║
╟────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 1          │ Alt.Atheism FAQ…  │ 1                 │ alt.atheism       ║
╟────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 2          │ Re: Gospel Dati…  │ 1                 │ alt.atheism       ║
╟────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 3          │ Re: university …  │ 1                 │ alt.atheism       ║
╟────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 4          │ Re: [soc.motss,…  │ 1                 │ alt.atheism       ║
╚════════════╧═══════════════════╧═══════════════════╧═══════════════════╝
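To see what each cleaning step does in isolation, here is a standalone copy of the `preprocessText` logic applied to a small post. The copy is renamed `cleanPost` so it does not clash with the notebook's function, and the sample text and address are invented for illustration. Note that the control-character step also turns newlines into spaces.

```typescript
// Standalone copy of the cleaning steps used in preprocessText, for illustration.
function cleanPost(text: string, maxLength = 5000): string {
  let cleaned = text;
  // 1. Strip email addresses such as "someone@example.com"
  cleaned = cleaned.replace(/[\w.-]+@[\w.-]+/g, "");
  // 2. Drop the first line (the "From:" header line)
  cleaned = cleaned.replace(/^(.*?)(?=\n)/, "");
  // 3. Remove the remaining header prefixes
  cleaned = cleaned.replace(/From: /g, "").replace(/\nSubject: /g, "");
  // 4. Turn control characters (including newlines) into spaces
  cleaned = cleaned.replace(/[\x00-\x1F\x7F]/g, " ");
  // 5. Truncate long posts
  return cleaned.slice(0, maxLength);
}

// Made-up post for illustration
const samplePost =
  "From: jane@example.com (Jane Doe)\nSubject: Orbital mechanics\n\nHow do transfer orbits work?";
console.log(JSON.stringify(cleanPost(samplePost))); // "Orbital mechanics  How do transfer orbits work?"
```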

Next, you will sample the data: take 100 data points per category from the training dataset (and 25 per category from the test dataset), then keep only the science (sci.*) categories for comparison.

import { DataFrame } from "danfojs-node";

async function sampleData(df: DataFrame, numSamples: number, classesToKeep: string[]): Promise<DataFrame> {
  const uniqueLabels = df.Label.unique().values;
  const sampledGroups = [];
  for (const label of uniqueLabels) {
    const labelGroup = df.query(df.Label.eq(label)).resetIndex();
    const groupSize = labelGroup.shape[0];
    if (groupSize > 0) {
      const sampledGroup = await labelGroup.sample(numSamples, { seed: 42 });
      sampledGroups.push(sampledGroup);
    }
  }
  const dfSampled = danfo.concat({
    dfList: sampledGroups,
    axis: 0,
  }) as DataFrame;
  const mask = dfSampled["Class Name"].values.map((name: string) =>
    classesToKeep.some((c: string) => name.includes(c))
  );
  const dfFiltered = dfSampled.query(mask).resetIndex();

  const classNames = dfFiltered["Class Name"].unique().values;
  const classToCode: Record<string, number> = {};
  classNames.forEach((name: string, i: number) => {
    classToCode[name] = i;
  });

  const encodedLabel = dfFiltered["Class Name"].values.map((val: string) => classToCode[val]);
  dfFiltered.addColumn("Encoded Label", encodedLabel, { inplace: true });

  return dfFiltered.resetIndex();
}
const TRAIN_NUM_SAMPLES = 100;
const TEST_NUM_SAMPLES = 25;
const CLASSES_TO_KEEP = "sci";
const dfTrainFinal = await sampleData(trainDfPreprocessed, TRAIN_NUM_SAMPLES, [CLASSES_TO_KEEP]);
const dfTestFinal = await sampleData(testDfPreprocessed, TEST_NUM_SAMPLES, [CLASSES_TO_KEEP]);
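`sampleData` re-encodes the remaining class names as consecutive integers starting at 0 ("Encoded Label"). This is needed because the `sparseCategoricalCrossentropy` loss used later expects each label to be an index in `[0, numClasses)`. The 1-based `Label` column from preprocessing (for example 12 for sci.crypt) would fall outside that range for a four-class model. Below is a minimal, dependency-free sketch of that encoding. `encodeLabels` is a hypothetical helper that assigns codes in order of first appearance.

```typescript
// Hypothetical helper: map class names to 0-based integer codes in order of first appearance.
function encodeLabels(classNames: string[]): { codes: number[]; classToCode: Map<string, number> } {
  const classToCode = new Map<string, number>();
  for (const name of classNames) {
    // The next unused code is the number of classes seen so far
    if (!classToCode.has(name)) classToCode.set(name, classToCode.size);
  }
  const codes = classNames.map((name) => classToCode.get(name) as number);
  return { codes, classToCode };
}

const { codes, classToCode } = encodeLabels(["sci.crypt", "sci.med", "sci.crypt", "sci.space"]);
console.log(codes); // [ 0, 1, 0, 2 ]
console.log(classToCode.size); // 3
```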
/* eslint-disable @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call */

import { DataFrame } from "danfojs-node";

const trainValueCounts = dfTrainFinal["Class Name"].valueCounts() as DataFrame;
trainValueCounts.print();
╔═════════════════╤═════╗
║ sci.crypt       │ 100 ║
╟─────────────────┼─────╢
║ sci.electronics │ 100 ║
╟─────────────────┼─────╢
║ sci.med         │ 100 ║
╟─────────────────┼─────╢
║ sci.space       │ 100 ║
╚═════════════════╧═════╝
/* eslint-disable @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call */

import { DataFrame } from "danfojs-node";

const testValueCounts = dfTestFinal["Class Name"].valueCounts() as DataFrame;
testValueCounts.print();
╔═════════════════╤════╗
║ sci.crypt       │ 25 ║
╟─────────────────┼────╢
║ sci.electronics │ 25 ║
╟─────────────────┼────╢
║ sci.med         │ 25 ║
╟─────────────────┼────╢
║ sci.space       │ 25 ║
╚═════════════════╧════╝
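`sampleData` draws the same number of posts from every class (stratified sampling) and passes a seed so the sample is reproducible, which is why each class shows exactly 100 and 25 posts above. The sketch below shows the same idea without danfo.js. It swaps in mulberry32, a small public-domain seeded PRNG, plus a partial Fisher-Yates shuffle; danfo.js's internal sampling may work differently. `stratifiedSample` is a hypothetical helper. Unlike `sampleData`, it filters to the sci.* classes before sampling. This yields the same per-class counts but skips sampling classes that would be discarded anyway.

```typescript
// mulberry32: small seeded PRNG returning floats in [0, 1).
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Hypothetical helper: draw `perClass` items without replacement from every class.
function stratifiedSample<T>(items: T[], getClass: (item: T) => string, perClass: number, seed: number): T[] {
  const rand = mulberry32(seed);
  // Group items by class, preserving insertion order
  const groups = new Map<string, T[]>();
  for (const item of items) {
    const key = getClass(item);
    const group = groups.get(key) ?? [];
    group.push(item);
    groups.set(key, group);
  }
  const sample: T[] = [];
  for (const [cls, group] of groups) {
    if (group.length < perClass) {
      throw new Error(`Class ${cls} has only ${group.length} items`);
    }
    // Partial Fisher-Yates: shuffle only the first perClass positions
    const copy = [...group];
    for (let i = 0; i < perClass; i++) {
      const j = i + Math.floor(rand() * (copy.length - i));
      [copy[i], copy[j]] = [copy[j], copy[i]];
    }
    sample.push(...copy.slice(0, perClass));
  }
  return sample;
}

// Made-up posts: 10 per class across three classes
type Post = { text: string; className: string };
const posts: Post[] = [];
for (const className of ["sci.crypt", "sci.med", "rec.autos"]) {
  for (let i = 0; i < 10; i++) posts.push({ text: `${className} post ${i}`, className });
}
// Keep only the sci.* classes, then take 4 posts per class
const sciPosts = posts.filter((p) => p.className.includes("sci"));
const sampled = stratifiedSample(sciPosts, (p) => p.className, 4, 42);
console.log(sampled.length); // 8
```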

Create the embeddings

In this section, you will generate an embedding for each piece of text using the Gemini API embedding model. To learn more about embeddings, visit the embeddings guide.

Note

Embeddings are requested in batches of up to BATCH_SIZE (100) texts per call, so large sample sizes can still take a long time!
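In `addEmbeddings`, texts are sent to the API `BATCH_SIZE` (100) at a time, so the number of requests is `ceil(n / BATCH_SIZE)`. The 400 sampled training posts need 4 requests and the 100 test posts need 1. A generic chunking helper makes the arithmetic explicit. `chunk` is a hypothetical name, not an SDK function.

```typescript
// Hypothetical helper: split an array into consecutive chunks of at most `size` items.
function chunk<T>(items: T[], size: number): T[][] {
  if (!Number.isInteger(size) || size <= 0) throw new Error("size must be a positive integer");
  const chunks: T[][] = [];
  for (let i = 0; i < items.length; i += size) {
    chunks.push(items.slice(i, i + size));
  }
  return chunks;
}

const trainTexts = Array.from({ length: 400 }, (_, i) => `post ${i}`);
const testTexts = Array.from({ length: 100 }, (_, i) => `post ${i}`);
console.log(chunk(trainTexts, 100).length); // 4
console.log(chunk(testTexts, 100).length); // 1
console.log(chunk(["a", "b", "c"], 2)); // [ [ 'a', 'b' ], [ 'c' ] ]
```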

API changes to Embeddings

Recent embedding models accept a task type parameter (taskType in the JavaScript SDK) and an optional title, which is only valid with the RETRIEVAL_DOCUMENT task type.

These parameters apply only to the recent embedding models. The task types are:

Task type             Description
RETRIEVAL_QUERY       Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT    Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY   Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION        Specifies that the embeddings will be used for classification.
CLUSTERING            Specifies that the embeddings will be used for clustering.

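In the JavaScript SDK these options go in the `config` object passed to `ai.models.embedContent`. This is how `addEmbeddings` requests `CLASSIFICATION` embeddings, and the title is assumed here to be passed as a `title` field alongside `taskType`. The sketch below is a hypothetical, dependency-free helper (`buildEmbedRequest` is not part of `@google/genai`). It assembles such a request object, enforces the rule that a title is only valid with `RETRIEVAL_DOCUMENT`, and makes no network call.

```typescript
// Task types listed in the table above.
type EmbeddingTaskType =
  | "RETRIEVAL_QUERY"
  | "RETRIEVAL_DOCUMENT"
  | "SEMANTIC_SIMILARITY"
  | "CLASSIFICATION"
  | "CLUSTERING";

// Shape of the request object; mirrors the fields this notebook passes to embedContent,
// plus the assumed optional `title` field.
interface EmbedRequest {
  model: string;
  contents: string[];
  config: { taskType: EmbeddingTaskType; title?: string };
}

// Hypothetical helper: builds the request and rejects a title for other task types.
function buildEmbedRequest(
  model: string,
  texts: string[],
  taskType: EmbeddingTaskType,
  title?: string,
): EmbedRequest {
  if (title !== undefined && taskType !== "RETRIEVAL_DOCUMENT") {
    throw new Error("title is only valid with taskType RETRIEVAL_DOCUMENT");
  }
  const config: EmbedRequest["config"] = { taskType };
  if (title !== undefined) config.title = title;
  return { model, contents: texts, config };
}

const req = buildEmbedRequest("models/text-embedding-004", ["How do rockets work?"], "CLASSIFICATION");
console.log(JSON.stringify(req.config)); // {"taskType":"CLASSIFICATION"}
```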
/* eslint-disable @typescript-eslint/no-unsafe-member-access, @typescript-eslint/no-unsafe-call, @typescript-eslint/no-unsafe-assignment */
import { DataFrame } from "danfojs-node";

const EMBEDDING_MODEL_ID = "models/text-embedding-004";
const BATCH_SIZE = 100;

async function addEmbeddings(df: DataFrame, textColumnName = "Text"): Promise<DataFrame> {
  const embeddings: number[][] = [];
  const display = tslab.newDisplay();
  display.text("Progress: 0%");

  for (let i = 0; i < df.shape[0]; i += BATCH_SIZE) {
    const batch = df[textColumnName].values.slice(i, i + BATCH_SIZE);
    const embeddingResponse = await ai.models.embedContent({
      model: EMBEDDING_MODEL_ID,
      contents: batch,
      config: {
        taskType: "CLASSIFICATION",
      },
    });
    const batchEmbeddings = embeddingResponse.embeddings?.map((e) => e.values ?? []) ?? [];
    embeddings.push(...batchEmbeddings);
    display.text(`Progress: ${Math.min(100, ((i + BATCH_SIZE) / df.shape[0]) * 100).toFixed(2)}%`);
  }

  const dfCopy = df.copy();
  dfCopy.addColumn("Embedding", new danfo.Series(embeddings), { inplace: true });
  return dfCopy;
}
const dfTrainWithEmbeddings = await addEmbeddings(dfTrainFinal);
Progress: 100.00%
const dfTestWithEmbeddings = await addEmbeddings(dfTestFinal);
Progress: 100.00%
dfTrainWithEmbeddings.head().print();
╔════════════╤═══════════════════╤═══════════════════╤═══════════════════╤═══════════════════╤═══════════════════╗
║            │ Text              │ Label             │ Class Name        │ Encoded Label     │ Embedding         ║
╟────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 0          │ Re: Re-inventin…  │ 12                │ sci.crypt         │ 0                 │ -0.0029816378,0…  ║
╟────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 1          │ Re: Source of r…  │ 12                │ sci.crypt         │ 0                 │ -0.020604927,0.…  ║
╟────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 2          │ Re: White House…  │ 12                │ sci.crypt         │ 0                 │ 0.00362508,0.01…  ║
╟────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 3          │ Re: How to dete…  │ 12                │ sci.crypt         │ 0                 │ -0.013612009,0.…  ║
╟────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ 4          │ Re: Another dat…  │ 12                │ sci.crypt         │ 0                 │ -0.016524253,0.…  ║
╚════════════╧═══════════════════╧═══════════════════╧═══════════════════╧═══════════════════╧═══════════════════╝

Build a simple classification model

Here you will define a simple model with one hidden layer and a softmax output layer with one unit per class. Each output is the predicted probability that a piece of text belongs to the corresponding newsgroup class. When you train the model, TensorFlow.js's fit() shuffles the training data before each epoch by default.

const tf = require("@tensorflow/tfjs-node") as typeof import("@tensorflow/tfjs-node");
import { LayersModel } from "@tensorflow/tfjs-node";

function buildClassificationModel(inputSize: number, numClasses: number): LayersModel {
  const model = tf.sequential();

  model.add(
    tf.layers.dense({
      inputShape: [inputSize],
      units: inputSize,
      activation: "relu",
    })
  );

  model.add(
    tf.layers.dense({
      units: numClasses,
      activation: "softmax",
    })
  );

  return model;
}
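To make the architecture concrete, the hidden layer computes relu(Wx + b). The softmax output layer then turns its scores into one probability per class, so the outputs are non-negative and sum to 1, and the predicted class is the index with the highest probability. Below is a plain-TypeScript sketch of that forward pass on tiny, made-up weights. TensorFlow.js does the same with tensors and stores each kernel as `[inputs, units]`, whereas this sketch uses `[units][inputs]` for readability. The helper names are illustrative.

```typescript
// y = W x + b for a dense layer, where W has shape [units][inputs] in this sketch.
function dense(x: number[], W: number[][], b: number[]): number[] {
  return W.map((row, j) => row.reduce((sum, w, i) => sum + w * x[i], b[j]));
}

const relu = (v: number[]): number[] => v.map((z) => Math.max(0, z));

// Numerically stable softmax: subtract the max score before exponentiating.
function softmax(v: number[]): number[] {
  const max = Math.max(...v);
  const exps = v.map((z) => Math.exp(z - max));
  const total = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / total);
}

// Tiny made-up example: 3-dimensional "embedding", 3 hidden units, 2 classes.
const x = [0.5, -1.0, 2.0];
const hidden = relu(dense(x, [[1, 0, 0], [0, 1, 0], [0, 0, 1]], [0, 0, 0])); // [0.5, 0, 2]
const probs = softmax(dense(hidden, [[1, 0, 0], [0, 0, 1]], [0, 0])); // scores [0.5, 2]
const predicted = probs.indexOf(Math.max(...probs));
console.log(probs, predicted);
```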
const embeddingSize = (dfTestWithEmbeddings.Embedding.values[0] as string).split(",").length;
const numClasses = dfTrainWithEmbeddings["Class Name"].unique().values.length;
console.log("Embedding size:", embeddingSize);
console.log("Number of classes:", numClasses);
Embedding size: 768
Number of classes: 4
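The embedding size above is obtained by splitting the stored value on commas, which indicates that the `Embedding` column holds each vector as a comma-separated string in this setup. When converting those strings back into numbers for training, it can help to validate them. A minimal sketch follows. `parseEmbedding` is a hypothetical helper that checks the expected length and rejects empty or non-numeric entries.

```typescript
// Hypothetical helper: parse a comma-separated embedding string into numbers,
// checking the expected dimensionality and rejecting empty or non-numeric entries.
function parseEmbedding(value: string, expectedSize: number): number[] {
  const vector = value.split(",").map((v) => (v.trim() === "" ? NaN : Number(v)));
  if (vector.length !== expectedSize) {
    throw new Error(`Expected ${expectedSize} values, got ${vector.length}`);
  }
  if (vector.some((n) => !Number.isFinite(n))) {
    throw new Error("Embedding contains an empty or non-numeric value");
  }
  return vector;
}

console.log(parseEmbedding("-0.5,0.25,1e-3", 3)); // [ -0.5, 0.25, 0.001 ]
```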
const classifier = buildClassificationModel(embeddingSize, numClasses);
classifier.summary();
__________________________________________________________________________________________
Layer (type)                Input Shape               Output shape              Param #   
==========================================================================================
dense_Dense1 (Dense)        [[null,768]]              [null,768]                590592    
__________________________________________________________________________________________
dense_Dense2 (Dense)        [[null,768]]              [null,4]                  3076      
==========================================================================================
Total params: 593668
Trainable params: 593668
Non-trainable params: 0
__________________________________________________________________________________________
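The parameter counts in the summary follow from the Dense-layer formula. A layer with `n` inputs and `m` units has an `n * m` weight matrix plus `m` biases, for `n * m + m` parameters. For the hidden layer that is 768 * 768 + 768 = 590,592, and for the output layer 768 * 4 + 4 = 3,076, giving 593,668 in total. The same arithmetic in code:

```typescript
// Parameters of a fully connected layer: weight matrix (inputs x units) plus one bias per unit.
function denseParams(inputs: number, units: number): number {
  return inputs * units + units;
}

const embeddingDim = 768;
const classCount = 4;
const hiddenParams = denseParams(embeddingDim, embeddingDim);
const outputParams = denseParams(embeddingDim, classCount);
console.log(hiddenParams, outputParams, hiddenParams + outputParams); // 590592 3076 593668
```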
classifier.compile({
  loss: "sparseCategoricalCrossentropy",
  optimizer: tf.train.adam(0.001),
  metrics: ["accuracy"],
});
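With `sparseCategoricalCrossentropy`, each label is a single class index rather than a one-hot vector. The loss for one example is the negative log of the probability the model assigns to the true class, averaged over the batch. The `accuracy` metric counts how often the highest-probability class equals the label. Below is a minimal sketch of both computations on made-up probabilities. TensorFlow.js also clips probabilities away from 0 before taking the log; the `1e-7` used here is an assumed value.

```typescript
// Mean sparse categorical cross-entropy: -log(p[trueClass]) averaged over examples.
function sparseCategoricalCrossentropy(probs: number[][], labels: number[], eps = 1e-7): number {
  const losses = probs.map((p, i) => {
    // Clip to avoid log(0); eps is an assumed constant
    const pTrue = Math.min(Math.max(p[labels[i]], eps), 1 - eps);
    return -Math.log(pTrue);
  });
  return losses.reduce((a, b) => a + b, 0) / losses.length;
}

// Fraction of examples whose highest-probability class matches the label.
function sparseAccuracy(probs: number[][], labels: number[]): number {
  const correct = probs.filter((p, i) => p.indexOf(Math.max(...p)) === labels[i]).length;
  return correct / probs.length;
}

// Made-up softmax outputs for 3 examples over 4 classes
const batchProbs = [
  [0.7, 0.1, 0.1, 0.1],
  [0.2, 0.5, 0.2, 0.1],
  [0.25, 0.25, 0.4, 0.1],
];
const batchLabels = [0, 1, 3];
console.log(sparseCategoricalCrossentropy(batchProbs, batchLabels).toFixed(4)); // 1.1175
console.log(sparseAccuracy(batchProbs, batchLabels)); // 0.6666666666666666
```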

Train the model to classify newsgroups

Finally, you can train the model. Training runs for at most 25 epochs, but early stopping on validation accuracy (with a patience of 3 epochs) ends it sooner to avoid overfitting, and a custom callback restores the weights from the epoch with the best validation accuracy.

import { Tensor } from "@tensorflow/tfjs-node";

const NUM_EPOCHS = 25;
// Mini-batch size for training (separate from the BATCH_SIZE used for embedding requests)
const TRAIN_BATCH_SIZE = 32;

// Labels are the 0-based "Encoded Label" codes; embeddings are stored as comma-separated strings
const yTrain = tf.tensor1d(dfTrainWithEmbeddings["Encoded Label"].values as number[]);
const xTrain = tf.tensor2d(
  dfTrainWithEmbeddings.Embedding.values.map((e: string) =>
    e.split(",").map((v: string) => parseFloat(v))
  ) as number[][]
);
const yVal = tf.tensor1d(dfTestWithEmbeddings["Encoded Label"].values as number[]);
const xVal = tf.tensor2d(
  dfTestWithEmbeddings.Embedding.values.map((e: string) => e.split(",").map((v: string) => parseFloat(v))) as number[][]
);

let bestAcc = 0;
let bestWeights: Tensor[] = [];

// Keeps a copy of the weights from the epoch with the highest validation accuracy
// and restores them when training ends.
const restoreBestWeightsCallback = new tf.CustomCallback({
  onEpochEnd: async (epoch, logs) => {
    const acc = logs?.val_acc ?? 0;
    if (acc > bestAcc) {
      bestAcc = acc;
      // Free the previously saved copies before replacing them
      bestWeights.forEach((w) => w.dispose());
      bestWeights = classifier.getWeights().map((w) => w.clone());
    }
  },
  onTrainEnd: async () => {
    if (bestWeights.length > 0) {
      classifier.setWeights(bestWeights);
    }
  },
});

// Stop once val_acc has not improved for 3 consecutive epochs
const earlyStopping = tf.callbacks.earlyStopping({
  monitor: "val_acc",
  patience: 3,
});

const history = await classifier.fit(xTrain, yTrain, {
  validationData: [xVal, yVal],
  batchSize: TRAIN_BATCH_SIZE,
  epochs: NUM_EPOCHS,
  callbacks: [earlyStopping, restoreBestWeightsCallback],
});
Epoch 1 / 25
248ms 620us/step - acc=0.355 loss=1.36 val_acc=0.300 val_loss=1.32 
Epoch 2 / 25
175ms 438us/step - acc=0.605 loss=1.24 val_acc=0.520 val_loss=1.23 
Epoch 3 / 25
182ms 454us/step - acc=0.822 loss=1.09 val_acc=0.800 val_loss=1.11 
Epoch 4 / 25
230ms 574us/step - acc=0.907 loss=0.927 val_acc=0.830 val_loss=0.970 
Epoch 5 / 25
236ms 589us/step - acc=0.942 loss=0.753 val_acc=0.740 val_loss=0.851 
Epoch 6 / 25
178ms 444us/step - acc=0.960 loss=0.592 val_acc=0.840 val_loss=0.717 
Epoch 7 / 25
222ms 556us/step - acc=0.945 loss=0.463 val_acc=0.820 val_loss=0.646 
Epoch 8 / 25
180ms 450us/step - acc=0.975 loss=0.363 val_acc=0.870 val_loss=0.573 
Epoch 9 / 25
179ms 447us/step - acc=0.982 loss=0.288 val_acc=0.890 val_loss=0.521 
Epoch 10 / 25
183ms 458us/step - acc=0.982 loss=0.232 val_acc=0.880 val_loss=0.481 
Epoch 11 / 25
184ms 459us/step - acc=0.990 loss=0.194 val_acc=0.880 val_loss=0.470 
Epoch 12 / 25
154ms 384us/step - acc=0.993 loss=0.166 val_acc=0.870 val_loss=0.454 
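Training stopped after 12 of the 25 epochs. To see why, the sketch below replays the logged val_acc values through a simplified version of the patience rule: each new best value resets a counter, and training stops once the metric has failed to improve for patience consecutive epochs. simulateEarlyStopping is a hypothetical helper, not part of TensorFlow.js, and it assumes the earlyStopping defaults of minDelta 0 and "higher is better" for an accuracy metric.

```typescript
// Hypothetical helper: replays patience-based early stopping on a metric where higher is better.
// Assumes minDelta = 0, i.e. any strict improvement counts as a new best.
function simulateEarlyStopping(
  values: number[],
  patience: number
): { stoppedAtEpoch: number; bestEpoch: number; bestValue: number } {
  let bestValue = -Infinity;
  let bestEpoch = 0;
  let wait = 0; // consecutive epochs without improvement

  for (let i = 0; i < values.length; i++) {
    const epoch = i + 1; // 1-indexed, like the training log
    if (values[i] > bestValue) {
      bestValue = values[i];
      bestEpoch = epoch;
      wait = 0;
    } else {
      wait++;
      if (wait >= patience) {
        return { stoppedAtEpoch: epoch, bestEpoch, bestValue };
      }
    }
  }
  // Ran out of epochs without triggering the stop condition.
  return { stoppedAtEpoch: values.length, bestEpoch, bestValue };
}

// val_acc for epochs 1-12, copied from the training log above.
const valAcc = [0.3, 0.52, 0.8, 0.83, 0.74, 0.84, 0.82, 0.87, 0.89, 0.88, 0.88, 0.87];
console.log(simulateEarlyStopping(valAcc, 3));
// { stoppedAtEpoch: 12, bestEpoch: 9, bestValue: 0.89 }
```

Epoch 9 is the best epoch, and epochs 10 to 12 fail to beat it, so the third miss ends training. This also explains why the restore callback reports a best validation accuracy of about 0.89 rather than the final epoch's 0.87.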
console.log("Training complete.");
// Log the best validation accuracy seen during training
console.log("Best validation accuracy:", bestAcc);
// Log the best weights (console.log truncates each array after 100 values)
if (bestWeights.length > 0) {
  console.log(
    "Best weights:",
    bestWeights.map((w) => w.dataSync())
  );
}
Training complete.
Best validation accuracy: 0.8899999856948853
Best weights: [
  Float32Array(589824) [
      -0.01458553597331047,   0.04647790640592575,  -0.11790028214454651,
       0.05183210223913193,  -0.04912508651614189,  -0.06355150789022446,
       -0.0712905302643776,  0.047946009784936905,  0.015527090057730675,
      0.029393872246146202,  0.011037350632250309,  -0.08336436748504639,
      -0.06574062258005142,  -0.04845399409532547,  -0.11720632761716843,
     -0.038035374134778976, -0.053958117961883545,   0.07522750645875931,
       -0.1012653335928917,  -0.10870110988616943,   0.04959658533334732,
       0.07256724685430527,  -0.03275788202881813, -0.006865565665066242,
      0.044840097427368164,  0.003523793537169695,  -0.03280274569988251,
      0.004993577022105455,  0.017164083197712898,   0.11629833281040192,
       -0.0934731662273407,  -0.08498488366603851,   0.00523048359900713,
       0.07440733909606934,   0.13781757652759552,   0.06484566628932953,
      0.015538545325398445,  -0.08032257854938507, -0.006935241166502237,
      -0.05966002121567726,   0.04631410911679268,  -0.08069337159395218,
      0.009128236211836338,  0.013827363960444927,   0.04692847281694412,
    -0.0006923891487531364, -0.053209491074085236,  0.020538924261927605,
      -0.00419811112806201,  -0.09361930936574936,   0.08986294269561768,
       0.03333975747227669,   0.13930268585681915, -0.047056253999471664,
     -0.032480500638484955,    0.1378415822982788,   0.04952394217252731,
       0.06769244372844696,  -0.06909222900867462,   0.05191900208592415,
      -0.03349985554814339,  -0.05605528503656387, -0.007196689490228891,
      -0.13059763610363007, -0.017757786437869072,   0.04363842308521271,
       0.04512101039290428,   0.10330261290073395, -0.009136350825428963,
       0.09158995002508163, -0.016558891162276268,   0.01614260859787464,
     -0.007019644603133202,   0.11496467888355255, -0.022889919579029083,
      0.055560607463121414,  -0.12549851834774017,   0.11550126224756241,
      0.019275400787591934, -0.047948457300662994, -0.013490136712789536,
      -0.11128153651952744,  0.041187405586242676, -0.041343335062265396,
      -0.03186627849936485,  0.027621088549494743,   0.11207965761423111,
       0.05049915611743927,   -0.1461654156446457,   0.10017067193984985,
       0.05056239292025566,  -0.03779350593686104,   0.06175878643989563,
     -0.003524564439430833,  -0.11690615117549896,    0.0310731939971447,
     -0.034123972058296204,    0.1483377367258072,  0.034658048301935196,
     -0.027962299063801765,
    ... 589724 more items
  ],
  Float32Array(768) [
     -0.005864144768565893, -0.002386872423812747,   0.004244392737746239,
       0.00875105895102024,  0.008688782341778278,  0.0075296214781701565,
      0.006209527142345905,  0.001587491249665618,                      0,
      0.009258225560188293, -0.005708571057766676,  0.0021059075370430946,
       0.01021265797317028,  0.011180863715708256,   0.007244514767080545,
      -0.00600215932354331,   0.00661898497492075,   0.008360839448869228,
      0.006417177617549896,  0.004829880781471729,   0.004867929965257645,
     -0.005557484924793243, -0.004238377325236797,   0.010516848415136337,
     -0.008243365213274956,                     0,     0.0038004070520401,
     -0.005345980171114206,                     0,   0.005777123384177685,
     0.0055799223482608795, -0.005188683047890663,  0.0031341915018856525,
       0.01139699388295412,  0.013903718441724777,   0.004163961857557297,
     -0.006733772810548544,  0.004898799117654562,  -0.005556218326091766,
     -0.004022425506263971,  0.010478651151061058,   0.003950197249650955,
     0.0056711225770413876,  0.003196362406015396,  0.0014308751560747623,
    -0.0077053336426615715, -0.005621710326522589,                      0,
    -0.0060013956390321255,  0.005474470090121031,   0.003841437166556716,
     -0.006005098111927509,  0.003691073041409254,  0.0027564126066863537,
      0.009270193055272102,  0.009483403526246548,                      0,
      0.016837066039443016,  0.011547042988240719,  -0.005190846975892782,
       0.00888818595558405, -0.007888734340667725,  -0.004556449130177498,
      0.004890375770628452, -0.007367891259491444,   0.010208001360297203,
    -0.0053018988110125065,  0.008063379675149918,   0.003171218791976571,
      0.002168072387576103,                     0,                      0,
    -0.0037830756045877934,  0.006514417938888073, -0.0038512840401381254,
      0.006942862644791603, 0.0074768634513020515,   0.008995870128273964,
     -0.006729383487254381, 0.0028954725712537766,  -0.006004408001899719,
      0.008131271228194237, -0.006005174480378628,                      0,
                         0, -0.005176699720323086,  0.0030023183207958937,
      -0.00657098600640893,   0.00681687006726861,   0.004189418628811836,
     0.0016160585219040513, 0.0023266817443072796,   0.001225532148964703,
     -0.006004612892866135, 0.0009142905473709106,  -0.005663653369992971,
     -0.005407575983554125, 0.0010831031249836087,  -0.005063849966973066,
     -0.006005085539072752,
    ... 668 more items
  ],
  Float32Array(3072) [
      0.0013996026245877147, -0.006913811434060335,  -0.08260757476091385,
        0.05730537697672844,  -0.04582637920975685,  -0.09442909806966782,
      -0.011321946047246456,  0.005978808738291264,   0.08318020403385162,
        0.06318077445030212,  -0.08340343832969666,  0.014638183638453484,
       -0.17294102907180786,    0.0585409514605999,    0.0259562861174345,
        0.10531898587942123,   0.16406947374343872,  -0.13908635079860687,
        0.01616712473332882,  -0.03588026016950607,    0.1316557377576828,
       -0.06362190842628479,   0.08512906730175018,  -0.17300662398338318,
         0.1439773589372635,  -0.07102224975824356,  -0.08264817297458649,
       -0.04243004694581032,  -0.09428377449512482,    0.1034160926938057,
       -0.10022439062595367,   0.11843250691890717,    0.0880105122923851,
        0.08713158220052719,   0.04737360402941704,  0.025206370279192924,
        -0.1943245828151703,   0.03154705464839935,   0.05390729010105133,
       -0.11333061754703522,  0.017211925238370895,  0.010337409563362598,
     -0.0011633769609034061,  -0.05308622866868973,   0.07414380460977554,
        0.06186636537313461,  -0.12107450515031815,  -0.05048438534140587,
         0.1376926302909851,  -0.14492228627204895,   0.05930162966251373,
       -0.05695727840065956,   0.17072591185569763,  -0.14493508636951447,
        0.05364081636071205,  -0.07935160398483276,   0.08721017092466354,
    -0.00025920101325027645,  0.014471221715211868,  -0.13984404504299164,
       0.005904374178498983, -0.042510829865932465,  -0.01762114278972149,
       0.012967304326593876,    0.0969497412443161,   0.06904834508895874,
       0.013907638378441334,  -0.11875277012586594,  -0.18011650443077087,
        0.08649379760026932,  0.062255628407001495,  0.003597079776227474,
       0.025880560278892517,   0.06165452301502228,   0.10112625360488892,
       -0.16312748193740845,  -0.08471381664276123,   0.18138280510902405,
        0.02282988280057907,  -0.09776809066534042,   0.06602950394153595,
        -0.1385907083749771,  -0.03451548516750336,   0.05321160703897476,
       -0.07755532115697861, -0.035777896642684937, -0.011863632127642632,
       -0.05437793210148811,  0.020390290766954422,     -0.04145497828722,
      -0.003215550212189555,   -0.0760529413819313,   0.05814632028341293,
      -0.043207958340644836,   0.15311144292354584,     -0.10636006295681,
      -0.003935716114938259,  -0.08422908931970596,   0.09260941296815872,
        -0.0730542466044426,
    ... 2972 more items
  ],
  Float32Array(4) [
    -0.0017931901384145021,
    -0.005592697765678167,
    0.006856919731944799,
    -0.0004198786336928606
  ]
]

Evaluate model performance

Use Model.evaluate() from TensorFlow.js to compute the loss and accuracy of the trained classifier on the test dataset.

// @ts-expect-error expected Tensor<Rank> type
const evalResult = classifier.evaluate(xVal, yVal, {
  batchSize: 32,
});

console.log("Evaluation complete.");
const loss = evalResult[0].dataSync()[0];
const acc = evalResult[1].dataSync()[0];
console.log(`Test Loss: ${loss.toFixed(4)}`);
console.log(`Test Accuracy: ${acc.toFixed(4)}`);
Evaluation complete.
Test Loss: 0.5208
Test Accuracy: 0.8900
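Model.evaluate() returns one scalar per compiled loss and metric. As an illustration of what those two numbers mean, the sketch below recomputes them by hand from a matrix of predicted class probabilities. It assumes the classifier ends in a softmax layer and was compiled with sparse categorical cross-entropy loss and an accuracy metric; the compile call is not shown in this section. evaluateProbabilities is a hypothetical helper, and the 1e-7 clip is an illustrative choice rather than the exact value TensorFlow.js uses.

```typescript
// Hypothetical helper: recomputes loss and accuracy for a classifier,
// assuming sparse categorical cross-entropy loss and an accuracy metric.
function evaluateProbabilities(
  probs: number[][], // one row of class probabilities per example
  labels: number[] // integer class index per example
): { loss: number; accuracy: number } {
  let lossSum = 0;
  let correct = 0;
  for (let i = 0; i < labels.length; i++) {
    const row = probs[i];
    // Cross-entropy for one example is -log(probability of the true class),
    // clipped so a zero probability does not produce Infinity.
    lossSum += -Math.log(Math.max(row[labels[i]], 1e-7));
    // The predicted class is the index of the highest probability (argmax).
    const predicted = row.indexOf(Math.max(...row));
    if (predicted === labels[i]) correct++;
  }
  return { loss: lossSum / labels.length, accuracy: correct / labels.length };
}

const sampleEval = evaluateProbabilities(
  [
    [0.7, 0.1, 0.1, 0.1], // true class 0, predicted 0
    [0.2, 0.5, 0.2, 0.1], // true class 2, predicted 1
  ],
  [0, 2]
);
console.log(sampleEval); // accuracy is 0.5, loss is about 0.983
```

Loss rewards confident correct predictions and penalizes confident mistakes, which is why it can keep falling while accuracy stays flat, as in the last epochs of the training log.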

One way to evaluate your model is to visualize how its performance changed during training. The plotHistory function below plots the training and validation loss and accuracy for each epoch.

import { History } from "@tensorflow/tfjs-node";

function plotHistory(history: History) {
  const epochs = history.epoch;

  const lossTrace = {
    x: epochs,
    y: history.history.loss,
    type: "scatter",
    mode: "lines+markers",
    name: "Train Loss",
    xaxis: "x1",
    yaxis: "y1",
    line: { color: "#1f77b4" },
  };

  const valLossTrace = {
    x: epochs,
    y: history.history.val_loss,
    type: "scatter",
    mode: "lines+markers",
    name: "Validation Loss",
    xaxis: "x1",
    yaxis: "y1",
    line: { color: "#ff7f0e" },
  };

  const accTrace = {
    x: epochs,
    // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
    y: history.history.acc ?? history.history.accuracy,
    type: "scatter",
    mode: "lines+markers",
    name: "Train Accuracy",
    xaxis: "x2",
    yaxis: "y2",
    line: { color: "#2ca02c" },
  };

  const valAccTrace = {
    x: epochs,
    // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
    y: history.history.val_acc ?? history.history.val_accuracy,
    type: "scatter",
    mode: "lines+markers",
    name: "Validation Accuracy",
    xaxis: "x2",
    yaxis: "y2",
    line: { color: "#d62728" },
  };

  const html = `
  <div style="width: 100%; height: 600px;">
    <div id="loss-acc-plot" style="width: 100%; height: 100%;"></div>
    <script src="https://cdn.jsdelivr.net/npm/plotly.js-dist@latest/plotly.min.js"></script>
    <script>
      const data = ${JSON.stringify([lossTrace, valLossTrace, accTrace, valAccTrace])};

      const layout = {
        grid: { rows: 1, columns: 2, pattern: "independent" },
        height: 600,
        width: 1200,
        margin: { t: 60, l: 60, r: 40, b: 60 },
        annotations: [
          {
            text: "Loss", x: 0.225, y: 1.12, showarrow: false,
            font: { size: 18 }, xref: "paper", yref: "paper"
          },
          {
            text: "Accuracy", x: 0.775, y: 1.12, showarrow: false,
            font: { size: 18 }, xref: "paper", yref: "paper"
          }
        ],
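        // Axis titles use the { text } object form, which works across Plotly.js versions,
        // including releases that dropped the plain-string shorthand.
        xaxis: { title: { text: "Epoch" }, domain: [0, 0.45] },
        yaxis: { title: { text: "Loss" } },
        xaxis2: { title: { text: "Epoch" }, domain: [0.55, 1] },
        yaxis2: { title: { text: "Accuracy" } },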
        legend: {
          orientation: "h",
          y: -0.2,
          x: 0.5,
          xanchor: "center",
          font: { size: 12 }
        }
      };

      Plotly.newPlot("loss-acc-plot", data, layout, { responsive: true });
    </script>
  </div>
  `;

  tslab.display.html(html);
}

plotHistory(history);
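The training log shows training accuracy climbing toward 1.0 while validation accuracy levels off in the high 0.8s, a sign of mild overfitting. The hypothetical accuracyGap helper below quantifies that gap per epoch. It takes a plain object with acc and val_acc arrays, an assumed simplification of history.history; TensorFlow.js types those entries more loosely, so passing history.history directly may need a cast.

```typescript
// Hypothetical shape: only the two accuracy series from a training history.
interface AccuracyHistory {
  acc: number[];
  val_acc: number[];
}

// Per-epoch difference between training and validation accuracy.
// A gap that keeps widening suggests the model is memorizing the training set.
function accuracyGap(h: AccuracyHistory): number[] {
  return h.acc.map((trainAcc, i) => trainAcc - h.val_acc[i]);
}

// Epochs 9-12, copied from the training log above.
const lastEpochs: AccuracyHistory = {
  acc: [0.982, 0.982, 0.99, 0.993],
  val_acc: [0.89, 0.88, 0.88, 0.87],
};
console.log(accuracyGap(lastEpochs).map((g) => g.toFixed(3)));
// [ '0.092', '0.102', '0.110', '0.123' ]
```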

Another way to view model performance, beyond loss and accuracy, is a confusion matrix. It shows not only how many examples were misclassified but also which class each misclassified example was assigned to. To build the confusion matrix for this multi-class classification problem, collect the actual labels from the test set and the labels the model predicts.

Start by generating the predicted class for each example in the validation set: Model.predict() returns a score for each class, and tf.argMax() picks the index of the highest score.

// @ts-expect-error expected Tensor<Rank> type
const yPredTensor = classifier.predict(xVal);
// @ts-expect-error expected Tensor<Rank> type
const yPredArray = Array.from(tf.argMax(yPredTensor, -1).dataSync());
const yValArray = Array.from(yVal.dataSync());
function computeConfusionMatrix(labels: number[], predictions: number[], numClasses: number): number[][] {
  const matrix: number[][] = Array.from({ length: numClasses }, () => Array<number>(numClasses).fill(0));
  for (let i = 0; i < labels.length; i++) {
    const actual = labels[i];
    const predicted = predictions[i];
    matrix[actual][predicted]++;
  }
  return matrix;
}

const cm = computeConfusionMatrix(yValArray, yPredArray, dfTrainWithEmbeddings["Encoded Label"].unique().values.length);
console.log("Confusion Matrix (as DataFrame):");
const cmDf = new danfo.DataFrame(cm, {
  columns: dfTrainWithEmbeddings["Class Name"].unique().values,
  index: dfTrainWithEmbeddings["Class Name"].unique().values,
});
cmDf.print();
Confusion Matrix (as DataFrame):
╔═════════════════╤═══════════════════╤═══════════════════╤═══════════════════╤═══════════════════╗
║                 │ sci.crypt         │ sci.electronics   │ sci.med           │ sci.space         ║
╟─────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ sci.crypt       │ 25                │ 0                 │ 0                 │ 0                 ║
╟─────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ sci.electronics │ 5                 │ 20                │ 0                 │ 0                 ║
╟─────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ sci.med         │ 0                 │ 2                 │ 22                │ 1                 ║
╟─────────────────┼───────────────────┼───────────────────┼───────────────────┼───────────────────╢
║ sci.space       │ 2                 │ 1                 │ 0                 │ 22                ║
╚═════════════════╧═══════════════════╧═══════════════════╧═══════════════════╧═══════════════════╝
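Each row of the matrix above is a true label and each column a predicted label, matching matrix[actual][predicted] in computeConfusionMatrix. From those counts you can read off per-class recall (diagonal divided by row sum) and precision (diagonal divided by column sum). The sketch below computes both; perClassMetrics is a hypothetical helper, and the counts are copied by hand from the printed matrix.

```typescript
// Hypothetical helper: per-class precision and recall plus overall accuracy
// from a confusion matrix whose rows are true labels and columns are predictions.
function perClassMetrics(matrix: number[][]): {
  precision: number[];
  recall: number[];
  accuracy: number;
} {
  const n = matrix.length;
  const precision: number[] = [];
  const recall: number[] = [];
  let correct = 0;
  let total = 0;

  for (let k = 0; k < n; k++) {
    const truePositives = matrix[k][k];
    // All examples whose true class is k.
    const rowSum = matrix[k].reduce((a, b) => a + b, 0);
    // All examples the model predicted as class k.
    let colSum = 0;
    for (let i = 0; i < n; i++) colSum += matrix[i][k];

    recall.push(rowSum === 0 ? 0 : truePositives / rowSum);
    precision.push(colSum === 0 ? 0 : truePositives / colSum);
    correct += truePositives;
    total += rowSum;
  }
  return { precision, recall, accuracy: total === 0 ? 0 : correct / total };
}

// Counts from the printed confusion matrix (order: crypt, electronics, med, space).
const printedCm = [
  [25, 0, 0, 0],
  [5, 20, 0, 0],
  [0, 2, 22, 1],
  [2, 1, 0, 22],
];
console.log(perClassMetrics(printedCm));
```

With these counts, overall accuracy is 89 / 100 = 0.89, matching the test accuracy reported by Model.evaluate(). Recall for sci.electronics is 20 / 25 = 0.8 and precision for sci.crypt is 25 / 32 ≈ 0.78: the five sci.electronics posts predicted as sci.crypt lower both numbers, which is exactly the kind of confusion a single accuracy figure hides.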
const classNames = dfTrainWithEmbeddings["Class Name"].unique().values as string[];

const html = `
<div style="width: 100%; height: 600px;">
  <div id="conf-matrix" style="width: 100%; height: 100%;"></div>
  <script src="https://cdn.jsdelivr.net/npm/plotly.js-dist@latest/plotly.min.js"></script>
  <script>
    const trace = {
      z: ${JSON.stringify(cm)},
      x: ${JSON.stringify(classNames)},
      y: ${JSON.stringify(classNames)},
      type: "heatmap",
      colorscale: "Blues",
      showscale: true,
      hoverongaps: false
    };

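    const matrixLayout = {
      title: { text: "Confusion Matrix for Newsgroup Test Dataset", font: { size: 18 } },
      xaxis: { title: { text: "Predicted Label" }, tickangle: -45 },
      // Reverse the y axis so the first class is drawn at the top, matching the printed table.
      yaxis: { title: { text: "True Label" }, autorange: "reversed" },
      height: 600,
      width: 700,
      margin: { t: 80, l: 100, r: 40, b: 100 }
    };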

    Plotly.newPlot("conf-matrix", [trace], matrixLayout);
  </script>
</div>
`;

tslab.display.html(html);