diff --git a/src/CSSMotion.tsx b/src/CSSMotion.tsx index bb8798f..c1ec2c2 100644 --- a/src/CSSMotion.tsx +++ b/src/CSSMotion.tsx @@ -1,6 +1,10 @@ /* eslint-disable react/default-props-match-prop-types, react/no-multi-comp, react/prop-types */ import { getDOM } from '@rc-component/util/lib/Dom/findDOMNode'; -import { getNodeRef, supportRef } from '@rc-component/util/lib/ref'; +import { + composeRef, + getNodeRef, + supportNodeRef, +} from '@rc-component/util/lib/ref'; import { clsx } from 'clsx'; import * as React from 'react'; import { useRef } from 'react'; @@ -106,6 +110,10 @@ export interface CSSMotionState { prevProps?: CSSMotionProps; } +export function isRefNotConsumed(children?: CSSMotionProps['children']) { + return children?.length < 2; +} + /** * `transitionSupport` is used for none transition test case. * Default we use browser transition event support check. @@ -189,12 +197,12 @@ export function genCSSMotion(config: CSSMotionConfig) { } // We should render children when motionStyle is sync with stepStatus - return React.useMemo(() => { + const returnNode = React.useMemo(() => { if (styleReady === 'NONE') { return null; } - let motionChildren: React.ReactNode; + let motionChildren: React.ReactElement | null; const mergedProps = { ...eventProps, visible }; if (!children) { @@ -246,25 +254,20 @@ export function genCSSMotion(config: CSSMotionConfig) { ); } - // Auto inject ref if child node not have `ref` props - if ( - React.isValidElement(motionChildren) && - supportRef(motionChildren) - ) { - const originNodeRef = getNodeRef(motionChildren); - - if (!originNodeRef) { - motionChildren = React.cloneElement( - motionChildren as React.ReactElement, - { - ref: nodeRef, - }, - ); - } + return motionChildren; + }, [idRef.current]); + + if (isRefNotConsumed(children) && supportNodeRef(returnNode)) { + const originNodeRef = getNodeRef(returnNode); + + if (originNodeRef !== nodeRef) { + return React.cloneElement(returnNode as any, { + ref: composeRef(originNodeRef, nodeRef), + }); } + } - return motionChildren; - }, [idRef.current]) as React.ReactElement; + return returnNode; }, ); diff --git a/src/CSSMotionList.tsx b/src/CSSMotionList.tsx index 6eedfd1..04ca651 100644 --- a/src/CSSMotionList.tsx +++ b/src/CSSMotionList.tsx @@ -1,7 +1,7 @@ /* eslint react/prop-types: 0 */ import * as React from 'react'; import type { CSSMotionProps } from './CSSMotion'; -import OriginCSSMotion from './CSSMotion'; +import OriginCSSMotion, { isRefNotConsumed } from './CSSMotion'; import type { KeyObject } from './util/diff'; import { diffKeys, @@ -59,6 +59,10 @@ export interface CSSMotionListProps ) => React.ReactElement; } +type ChildrenWithoutRef = ( + props: Parameters[0], +) => ReturnType; + export interface CSSMotionListState { keyEntities: KeyObject[]; } @@ -174,7 +178,13 @@ export function genCSSMotionList( } }} > - {(props, ref) => children({ ...props, index }, ref)} + {isRefNotConsumed(children) + ? props => + (children as ChildrenWithoutRef)({ + ...props, + index, + }) + : (props, ref) => children({ ...props, index }, ref)} ); })} diff --git a/tests/CSSMotion.spec.tsx b/tests/CSSMotion.spec.tsx index 8cb5b32..848bfe9 100644 --- a/tests/CSSMotion.spec.tsx +++ b/tests/CSSMotion.spec.tsx @@ -942,6 +942,51 @@ describe('CSSMotion', () => { expect(ReactDOM.findDOMNode).not.toHaveBeenCalled(); }); + + it('supports existing child refs for motion end', () => { + const motionRef = React.createRef(); + const childRef = React.createRef(); + + const Demo = ({ visible }: { visible: boolean }) => ( + + {({ style, className }) => ( +
+ )} + + ); + + const { container, rerender } = render(); + + act(() => { + jest.runAllTimers(); + }); + + expect(motionRef.current.nativeElement).toBe(childRef.current); + + rerender(); + + act(() => { + jest.runAllTimers(); + }); + + fireEvent.transitionEnd(childRef.current!); + + act(() => { + jest.runAllTimers(); + }); + + expect(container.querySelector('.motion-box')).toBeFalsy(); + expect(ReactDOM.findDOMNode).not.toHaveBeenCalled(); + }); }); describe('onVisibleChanged', () => {