Fix softmax scale arg passing #903

Merged 1 commit on Dec 24, 2024


hanzhi713
Member

PR #894 broke TPU training on the legacy codepath: the TPU kernels from JAX expect the scale argument to be named sm_scale, not softmax_scale. The legacy codepath is taken whenever segment_ids is used.

The error was caught by unit tests that run only on TPU, which is why the CPU unit tests did not catch it.
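
For readers unfamiliar with this class of bug, here is a minimal sketch of the keyword mismatch described above. It is plain Python, not the actual code from this PR or from JAX: `legacy_tpu_kernel`, `attention_logit_broken`, and `attention_logit_fixed` are hypothetical stand-ins, and the only detail taken from the description is that the kernel's keyword is `sm_scale` while the caller uses `softmax_scale`.

```python
import math


def legacy_tpu_kernel(q, k, *, sm_scale=1.0):
    """Hypothetical stand-in for a kernel whose scale keyword is `sm_scale`."""
    # Scaled dot product of two plain vectors; a real kernel works on tensors.
    return sm_scale * sum(a * b for a, b in zip(q, k))


def attention_logit_broken(q, k, *, softmax_scale=1.0):
    # Forwards the keyword under the wrapper's own name, which the kernel rejects.
    return legacy_tpu_kernel(q, k, softmax_scale=softmax_scale)


def attention_logit_fixed(q, k, *, softmax_scale=1.0):
    # Translates the wrapper's keyword into the name the kernel expects.
    return legacy_tpu_kernel(q, k, sm_scale=softmax_scale)


q, k = [1.0, 2.0], [3.0, 4.0]
scale = 1.0 / math.sqrt(len(q))

# The broken wrapper fails only when this path is actually executed.
try:
    attention_logit_broken(q, k, softmax_scale=scale)
    broken_error = None
except TypeError as e:
    broken_error = e

fixed_logit = attention_logit_fixed(q, k, softmax_scale=scale)
```

Because the failure is a `TypeError` raised at call time, it only surfaces when the affected path actually runs, which is consistent with it showing up only in tests that exercise the TPU kernels.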

hanzhi713 requested review from ruomingp, markblee, and a team as code owners on December 24, 2024 00:22
ruomingp added this pull request to the merge queue on Dec 24, 2024
Merged via the queue into apple:main with commit e4ff72c Dec 24, 2024
6 checks passed