diff --git a/packages/transaction-pay-controller/CHANGELOG.md b/packages/transaction-pay-controller/CHANGELOG.md index 3254ccfca74..ddfcae1b402 100644 --- a/packages/transaction-pay-controller/CHANGELOG.md +++ b/packages/transaction-pay-controller/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Extract shared `buildTokenData` utility and add `TokenListController` fallback for token lookups ([#7774](https://github.com/MetaMask/core/pull/7774)) + ## [12.0.2] ### Changed diff --git a/packages/transaction-pay-controller/src/actions/update-payment-token.test.ts b/packages/transaction-pay-controller/src/actions/update-payment-token.test.ts index 8b0d24dfc9e..6a94bb2962b 100644 --- a/packages/transaction-pay-controller/src/actions/update-payment-token.test.ts +++ b/packages/transaction-pay-controller/src/actions/update-payment-token.test.ts @@ -2,12 +2,8 @@ import type { TransactionMeta } from '@metamask/transaction-controller'; import { noop } from 'lodash'; import { updatePaymentToken } from './update-payment-token'; -import type { TransactionData } from '../types'; -import { - getTokenBalance, - getTokenInfo, - getTokenFiatRate, -} from '../utils/token'; +import type { TransactionData, TransactionPaymentToken } from '../types'; +import { buildTokenData } from '../utils/token'; import { getTransaction } from '../utils/transaction'; jest.mock('../utils/token'); @@ -18,18 +14,25 @@ const CHAIN_ID_MOCK = '0x1'; const FROM_MOCK = '0x456'; const TRANSACTION_ID_MOCK = '123-456'; +const PAYMENT_TOKEN_MOCK: TransactionPaymentToken = { + address: TOKEN_ADDRESS_MOCK, + balanceFiat: '2.46', + balanceHuman: '1.23', + balanceRaw: '1230000', + balanceUsd: '3.69', + chainId: CHAIN_ID_MOCK, + decimals: 6, + symbol: 'TST', +}; + describe('Update Payment Token Action', () => { - const getTokenBalanceMock = jest.mocked(getTokenBalance); - const getTokenInfoMock = jest.mocked(getTokenInfo); - const getTokenFiatRateMock = jest.mocked(getTokenFiatRate); + const buildTokenDataMock = jest.mocked(buildTokenData); const getTransactionMock = jest.mocked(getTransaction); beforeEach(() => { jest.resetAllMocks(); - getTokenInfoMock.mockReturnValue({ decimals: 6, symbol: 'TST' }); - getTokenBalanceMock.mockReturnValue('1230000'); - getTokenFiatRateMock.mockReturnValue({ fiatRate: '2.0', usdRate: '3.0' }); + buildTokenDataMock.mockReturnValue(PAYMENT_TOKEN_MOCK); getTransactionMock.mockReturnValue({ id: TRANSACTION_ID_MOCK, @@ -52,43 +55,23 @@ describe('Update Payment Token Action', () => { }, ); + expect(buildTokenDataMock).toHaveBeenCalledWith({ + chainId: CHAIN_ID_MOCK, + from: FROM_MOCK, + messenger: {}, + tokenAddress: TOKEN_ADDRESS_MOCK, + }); + expect(updateTransactionDataMock).toHaveBeenCalledTimes(1); const transactionDataMock = {} as TransactionData; updateTransactionDataMock.mock.calls[0][1](transactionDataMock); - expect(transactionDataMock.paymentToken).toStrictEqual({ - address: TOKEN_ADDRESS_MOCK, - balanceFiat: '2.46', - balanceHuman: '1.23', - balanceRaw: '1230000', - balanceUsd: '3.69', - chainId: CHAIN_ID_MOCK, - decimals: 6, - symbol: 'TST', - }); - }); - - it('throws if decimals not found', () => { - getTokenInfoMock.mockReturnValue(undefined); - - expect(() => - updatePaymentToken( - { - chainId: CHAIN_ID_MOCK, - tokenAddress: TOKEN_ADDRESS_MOCK, - transactionId: TRANSACTION_ID_MOCK, - }, - { - messenger: {} as never, - updateTransactionData: noop, - }, - ), - ).toThrow('Payment token not found'); + expect(transactionDataMock.paymentToken).toStrictEqual(PAYMENT_TOKEN_MOCK); }); - it('throws if token fiat rate not found', () => { - getTokenFiatRateMock.mockReturnValue(undefined); + it('throws if token data not found', () => { + buildTokenDataMock.mockReturnValue(undefined); expect(() => updatePaymentToken( diff --git a/packages/transaction-pay-controller/src/actions/update-payment-token.ts b/packages/transaction-pay-controller/src/actions/update-payment-token.ts index e9c7108af01..1405a90b7a9 100644 --- a/packages/transaction-pay-controller/src/actions/update-payment-token.ts +++ b/packages/transaction-pay-controller/src/actions/update-payment-token.ts @@ -1,19 +1,13 @@ import { createModuleLogger } from '@metamask/utils'; import type { Hex } from '@metamask/utils'; -import { BigNumber } from 'bignumber.js'; import type { TransactionPayControllerMessenger } from '..'; import { projectLogger } from '../logger'; import type { - TransactionPaymentToken, UpdatePaymentTokenRequest, UpdateTransactionDataCallback, } from '../types'; -import { - getTokenBalance, - getTokenFiatRate, - getTokenInfo, -} from '../utils/token'; +import { buildTokenData } from '../utils/token'; import { getTransaction } from '../utils/transaction'; const log = createModuleLogger(projectLogger, 'update-payment-token'); @@ -42,7 +36,7 @@ export function updatePaymentToken( throw new Error('Transaction not found'); } - const paymentToken = getPaymentToken({ + const paymentToken = buildTokenData({ chainId, from: transaction?.txParams.from as Hex, messenger, @@ -59,63 +53,3 @@ export function updatePaymentToken( data.paymentToken = paymentToken; }); } - -/** - * Generate the full payment token data from a token address and chain ID. - * - * @param request - The payment token request parameters. - * @param request.chainId - The chain ID. - * @param request.from - The address to get the token balance for. - * @param request.messenger - The transaction pay controller messenger. - * @param request.tokenAddress - The token address. - * @returns The payment token or undefined if the token data could not be retrieved. - */ -function getPaymentToken({ - chainId, - from, - messenger, - tokenAddress, -}: { - chainId: Hex; - from: Hex; - messenger: TransactionPayControllerMessenger; - tokenAddress: Hex; -}): TransactionPaymentToken | undefined { - const { decimals, symbol } = - getTokenInfo(messenger, tokenAddress, chainId) ?? {}; - - if (decimals === undefined || !symbol) { - return undefined; - } - - const tokenFiatRate = getTokenFiatRate(messenger, tokenAddress, chainId); - - if (tokenFiatRate === undefined) { - return undefined; - } - - const balance = getTokenBalance(messenger, from, chainId, tokenAddress); - const balanceRawValue = new BigNumber(balance); - const balanceHumanValue = new BigNumber(balance).shiftedBy(-decimals); - const balanceRaw = balanceRawValue.toFixed(0); - const balanceHuman = balanceHumanValue.toString(10); - - const balanceFiat = balanceHumanValue - .multipliedBy(tokenFiatRate.fiatRate) - .toString(10); - - const balanceUsd = balanceHumanValue - .multipliedBy(tokenFiatRate.usdRate) - .toString(10); - - return { - address: tokenAddress, - balanceFiat, - balanceHuman, - balanceRaw, - balanceUsd, - chainId, - decimals, - symbol, - }; -} diff --git a/packages/transaction-pay-controller/src/tests/messenger-mock.ts b/packages/transaction-pay-controller/src/tests/messenger-mock.ts index b6d69d4f22d..d8bdc6732e3 100644 --- a/packages/transaction-pay-controller/src/tests/messenger-mock.ts +++ b/packages/transaction-pay-controller/src/tests/messenger-mock.ts @@ -2,6 +2,7 @@ import type { TokensControllerGetStateAction } from '@metamask/assets-controller import type { TokenBalancesControllerGetStateAction } from '@metamask/assets-controllers'; import type { TokenRatesControllerGetStateAction } from '@metamask/assets-controllers'; import type { AccountTrackerControllerGetStateAction } from '@metamask/assets-controllers'; +import type { GetTokenListState } from '@metamask/assets-controllers'; import type { BridgeStatusControllerGetStateAction } from '@metamask/bridge-status-controller'; import type { MessengerActions, @@ -93,6 +94,10 @@ export function getMessengerMock({ TokensControllerGetStateAction['handler'] > = jest.fn(); + const getTokenListControllerStateMock: jest.MockedFn< + GetTokenListState['handler'] + > = jest.fn(); + const getTokenBalanceControllerStateMock: jest.MockedFn< TokenBalancesControllerGetStateAction['handler'] > = jest.fn(); @@ -197,6 +202,11 @@ export function getMessengerMock({ getTokensControllerStateMock, ); + messenger.registerActionHandler( + 'TokenListController:getState', + getTokenListControllerStateMock, + ); + messenger.registerActionHandler( 'TokenBalancesController:getState', getTokenBalanceControllerStateMock, @@ -263,6 +273,7 @@ export function getMessengerMock({ getRemoteFeatureFlagControllerStateMock, getStrategyMock, getTokenBalanceControllerStateMock, + getTokenListControllerStateMock, getTokenRatesControllerStateMock, getTokensControllerStateMock, getTransactionControllerStateMock, diff --git a/packages/transaction-pay-controller/src/utils/token.test.ts b/packages/transaction-pay-controller/src/utils/token.test.ts index ffdaa5fc353..093434f4426 100644 --- a/packages/transaction-pay-controller/src/utils/token.test.ts +++ b/packages/transaction-pay-controller/src/utils/token.test.ts @@ -4,6 +4,7 @@ import type { TokenRatesControllerState } from '@metamask/assets-controllers'; import type { Hex } from '@metamask/utils'; import { + buildTokenData, getTokenBalance, getTokenInfo, getTokenFiatRate, @@ -471,4 +472,173 @@ describe('Token Utils', () => { ]); }); }); + + describe('buildTokenData', () => { + beforeEach(() => { + getTokensControllerStateMock.mockReturnValue({ + allTokens: { + [CHAIN_ID_MOCK]: { + test123: [ + { + address: TOKEN_ADDRESS_MOCK.toLowerCase() as Hex, + decimals: DECIMALS_MOCK, + symbol: SYMBOL_MOCK, + }, + ], + }, + }, + } as never); + + getTokenBalanceControllerStateMock.mockReturnValue({ + tokenBalances: { + [FROM_MOCK.toLowerCase()]: { + [CHAIN_ID_MOCK]: { + [TOKEN_ADDRESS_MOCK]: '0x1e8480' as Hex, // 2000000 + }, + }, + }, + } as never); + + getCurrencyRateControllerStateMock.mockReturnValue({ + currencyRates: { + [TICKER_MOCK]: { + conversionRate: 2.0, + usdConversionRate: 3.0, + }, + }, + }); + + getTokenRatesControllerStateMock.mockReturnValue({ + marketData: { + [CHAIN_ID_MOCK]: { + [TOKEN_ADDRESS_MOCK]: { + price: 1.5, + }, + }, + }, + } as TokenRatesControllerState); + + findNetworkClientIdByChainIdMock.mockReturnValue(NETWORK_CLIENT_ID_MOCK); + + getNetworkClientByIdMock.mockReturnValue({ + configuration: { ticker: TICKER_MOCK }, + } as never); + }); + + it('builds token data with balance and fiat values', () => { + const result = buildTokenData({ + chainId: CHAIN_ID_MOCK, + from: FROM_MOCK, + messenger, + tokenAddress: TOKEN_ADDRESS_MOCK, + }); + + expect(result).toStrictEqual({ + address: TOKEN_ADDRESS_MOCK, + balanceFiat: '6', + balanceHuman: '2', + balanceRaw: '2000000', + balanceUsd: '9', + chainId: CHAIN_ID_MOCK, + decimals: DECIMALS_MOCK, + symbol: SYMBOL_MOCK, + }); + }); + + it('returns undefined if token info not found', () => { + getTokensControllerStateMock.mockReturnValue({ + allTokens: {}, + } as never); + + const result = buildTokenData({ + chainId: CHAIN_ID_MOCK, + from: FROM_MOCK, + messenger, + tokenAddress: TOKEN_ADDRESS_MOCK, + }); + + expect(result).toBeUndefined(); + }); + + it('returns undefined if fiat rate not found', () => { + getCurrencyRateControllerStateMock.mockReturnValue({ + currencyRates: {}, + }); + + getTokenRatesControllerStateMock.mockReturnValue({ + marketData: {}, + } as TokenRatesControllerState); + + const result = buildTokenData({ + chainId: CHAIN_ID_MOCK, + from: FROM_MOCK, + messenger, + tokenAddress: TOKEN_ADDRESS_MOCK, + }); + + expect(result).toBeUndefined(); + }); + }); + + describe('getTokenInfo - TokenListController fallback', () => { + it('falls back to TokenListController for unknown tokens', () => { + const { + messenger: testMessenger, + getTokensControllerStateMock: tokensStateMock, + getTokenListControllerStateMock, + } = getMessengerMock(); + + tokensStateMock.mockReturnValue({ + allTokens: {}, + } as never); + + getTokenListControllerStateMock.mockReturnValue({ + tokensChainsCache: { + [CHAIN_ID_MOCK]: { + data: { + [TOKEN_ADDRESS_MOCK.toLowerCase()]: { + decimals: 8, + symbol: 'LISTTOKEN', + }, + }, + }, + }, + } as never); + + const result = getTokenInfo( + testMessenger, + TOKEN_ADDRESS_MOCK, + CHAIN_ID_MOCK, + ); + + expect(result).toStrictEqual({ + decimals: 8, + symbol: 'LISTTOKEN', + }); + }); + + it('returns undefined if token not in TokenListController', () => { + const { + messenger: testMessenger, + getTokensControllerStateMock: tokensStateMock, + getTokenListControllerStateMock, + } = getMessengerMock(); + + tokensStateMock.mockReturnValue({ + allTokens: {}, + } as never); + + getTokenListControllerStateMock.mockReturnValue({ + tokensChainsCache: {}, + } as never); + + const result = getTokenInfo( + testMessenger, + TOKEN_ADDRESS_MOCK, + CHAIN_ID_MOCK, + ); + + expect(result).toBeUndefined(); + }); + }); }); diff --git a/packages/transaction-pay-controller/src/utils/token.ts b/packages/transaction-pay-controller/src/utils/token.ts index b05f0a5f32c..3596ad08b96 100644 --- a/packages/transaction-pay-controller/src/utils/token.ts +++ b/packages/transaction-pay-controller/src/utils/token.ts @@ -10,7 +10,11 @@ import { NATIVE_TOKEN_ADDRESS, POLYGON_USDCE_ADDRESS, } from '../constants'; -import type { FiatRates, TransactionPayControllerMessenger } from '../types'; +import type { + FiatRates, + TransactionPayControllerMessenger, + TransactionPaymentToken, +} from '../types'; const STABLECOINS: Record = { [CHAIN_ID_ARBITRUM]: [ARBITRUM_USDC_ADDRESS.toLowerCase() as Hex], @@ -134,6 +138,7 @@ export function getTokenInfo( const isNative = normalizedTokenAddress === getNativeToken(chainId).toLowerCase(); + // First, check user's added tokens const token = Object.values(controllerState.allTokens?.[chainId] ?? {}) .flat() .find( @@ -141,14 +146,25 @@ export function getTokenInfo( singleToken.address.toLowerCase() === normalizedTokenAddress, ); - if (!token && !isNative) { - return undefined; - } - if (token && !isNative) { return { decimals: Number(token.decimals), symbol: token.symbol }; } + // Fallback to TokenListController for known tokens not yet added by user + if (!token && !isNative) { + const tokenListInfo = getTokenInfoFromTokenList( + messenger, + tokenAddress, + chainId, + ); + + if (tokenListInfo) { + return tokenListInfo; + } + + return undefined; + } + const ticker = getTicker(chainId, messenger); if (!ticker) { @@ -263,3 +279,100 @@ function getTicker( return undefined; } } + +/** + * Get token info from TokenListController. + * This provides a fallback for tokens that aren't in the user's token list yet. + * + * @param messenger - Controller messenger. + * @param tokenAddress - Address of the token contract. + * @param chainId - Id of the chain. + * @returns The token info or undefined if not found. + */ +function getTokenInfoFromTokenList( + messenger: TransactionPayControllerMessenger, + tokenAddress: Hex, + chainId: Hex, +): { decimals: number; symbol: string } | undefined { + try { + const tokenListState = messenger.call('TokenListController:getState'); + const normalizedAddress = tokenAddress.toLowerCase() as Hex; + + // TokenListController stores tokens by chain ID and address + const tokensByChain = tokenListState.tokensChainsCache?.[chainId]; + const tokenData = tokensByChain?.data?.[normalizedAddress]; + + if (tokenData?.decimals !== undefined && tokenData?.symbol) { + return { + decimals: Number(tokenData.decimals), + symbol: tokenData.symbol, + }; + } + + return undefined; + } catch { + // TokenListController might not be available + return undefined; + } +} + +/** + * Build full token data from a token address and chain ID. + * Used by both payment token and selected token update actions. + * + * @param request - The token request parameters. + * @param request.chainId - The chain ID. + * @param request.from - The address to get the token balance for. + * @param request.messenger - The transaction pay controller messenger. + * @param request.tokenAddress - The token address. + * @returns The token data or undefined if the token data could not be retrieved. + */ +export function buildTokenData({ + chainId, + from, + messenger, + tokenAddress, +}: { + chainId: Hex; + from: Hex; + messenger: TransactionPayControllerMessenger; + tokenAddress: Hex; +}): TransactionPaymentToken | undefined { + const { decimals, symbol } = + getTokenInfo(messenger, tokenAddress, chainId) ?? {}; + + if (decimals === undefined || !symbol) { + return undefined; + } + + const tokenFiatRate = getTokenFiatRate(messenger, tokenAddress, chainId); + + if (tokenFiatRate === undefined) { + return undefined; + } + + const balance = getTokenBalance(messenger, from, chainId, tokenAddress); + const balanceRawValue = new BigNumber(balance); + const balanceHumanValue = new BigNumber(balance).shiftedBy(-decimals); + const balanceRaw = balanceRawValue.toFixed(0); + const balanceHuman = balanceHumanValue.toString(10); + + const balanceFiat = balanceHumanValue + .multipliedBy(tokenFiatRate.fiatRate) + .toString(10); + + const balanceUsd = balanceHumanValue + .multipliedBy(tokenFiatRate.usdRate) + .toString(10); + + return { + address: tokenAddress, + balanceFiat, + balanceHuman, + balanceRaw, + balanceUsd, + chainId, + decimals, + symbol, + }; +}