How to use CodeT5 with LangChain

Here is how to wrap a CodeT5 model in a custom LangChain LLM class, so you can use it for code generation, translation, or analysis tasks in your LangChain application.

Create a custom CodeT5 LLM class for LangChain

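from typing import Any, List, Mapping, Optional

import torch
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.prompts import PromptTemplate  # used in the usage example below
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class CodeT5LLM(LLM):
    """Custom LangChain LLM backed by a Hugging Face CodeT5 seq2seq checkpoint."""

    model_name: str
    max_length: int = 128  # upper bound on the generated sequence length, in tokens
    tokenizer: Any = None
    model: Any = None  # concrete class (e.g. T5ForConditionalGeneration) depends on the checkpoint

    def __init__(self, model_name: str, **kwargs: Any):
        super().__init__(model_name=model_name, **kwargs)
        # Load the tokenizer and weights once, when the wrapper is created.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Passing the full encoding also hands the attention mask to generate().
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True)

        # Inference only: no gradients needed.
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_length=self.max_length)

        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        # generate() knows nothing about LangChain stop sequences, so cut the text manually.
        for s in stop or []:
            text = text.split(s)[0]
        return text

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model_name": self.model_name, "max_length": self.max_length}

    @property
    def _llm_type(self) -> str:
        # Must be a property: LangChain reads it as an attribute, not a method call.
        return "custom"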

Use the CodeT5 LLM in your LangChain app

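Then, you can use the CodeT5 LLM in a chain like this. The example loads Salesforce/codet5-base-multi-sum, a checkpoint fine-tuned for code summarization, so the result is a short natural-language summary of the function. The first run downloads the checkpoint from the Hugging Face Hub.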

codet5_llm = CodeT5LLM(model_name="Salesforce/codet5-base-multi-sum")

# This checkpoint summarizes code, so the template just passes the source through.
prompt = PromptTemplate(
    input_variables=["code"],
    template="{code}",
)

chain = prompt | codet5_llm
result = chain.invoke({"code": "def hello_world():\n    print('Hello, World!')"})

print(result)

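Adjust the prompt template and generation settings (for example, max_length) to your use case and to the CodeT5 checkpoint you load. Salesforce/codet5-base-multi-sum takes raw source code and returns a summary, while other CodeT5 checkpoints are trained for different tasks and may expect differently formatted input.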
