Posted by Lizao (Larry) Li, Software Engineer, and Rob Carver, Research Scientist, Google Research

Accurate weather forecasts can have a direct impact on people’s lives, from helping make routine decisions, like what to pack for a day’s activities, to informing urgent actions, for example, protecting people in the face of hazardous weather conditions. The importance of accurate and timely weather forecasts will only increase as the climate changes. Recognizing this, we at Google have been investing in weather and climate research to help ensure that the forecasting technology of tomorrow can meet the demand for reliable weather information. Some of our recent innovations include MetNet-3, Google's high-resolution forecasts up to 24 hours into the future, and GraphCast, a weather model that can predict weather up to 10 days ahead.

Weather is inherently stochastic. To quantify the uncertainty, traditional methods rely on physics-based simulation to generate an ensemble of forecasts. However, it is computationally costly to generate an ensemble large enough that rare and extreme weather events can be discerned and characterized accurately. With that in mind, we are excited to announce our latest innovation designed to accelerate progress in weather forecasting, Scalable Ensemble Envelope Diffusion Sampler (SEEDS), recently published in Science Advances. SEEDS is a generative AI model that can efficiently generate ensembles of weather forecasts at scale at a small fraction of the cost of traditional physics-based forecasting models. This technology opens up novel opportunities for weather and climate science, and it represents one of the first applications to weather and climate forecasting of probabilistic diffusion models, a generative AI technology behind recent advances in media generation.
The need for probabilistic forecasts: the butterfly effect

In December 1972, at the American Association for the Advancement of Science meeting in Washington, D.C., MIT meteorology professor Ed Lorenz gave a talk entitled “Does the Flap of a Butterfly's Wings in Brazil Set Off a Tornado in Texas?”, which contributed to the term “butterfly effect”. He was building on his earlier, landmark 1963 paper where he examined the feasibility of “very-long-range weather prediction” and described how errors in initial conditions grow exponentially when integrated in time with numerical weather prediction models. This exponential error growth, known as chaos, results in a deterministic predictability limit that restricts the use of individual forecasts in decision making, because they do not quantify the inherent uncertainty of weather conditions. This is particularly problematic when forecasting extreme weather events, such as hurricanes, heatwaves, or floods.

Recognizing the limitations of deterministic forecasts, weather agencies around the world issue probabilistic forecasts. Such forecasts are based on ensembles of deterministic forecasts, each of which is generated by including synthetic noise in the initial conditions and stochasticity in the physical processes. Leveraging the fast error growth rate in weather models, the forecasts in an ensemble are purposefully different: the initial uncertainties are tuned to generate runs that are as different as possible, and the stochastic processes in the weather model introduce additional differences during the model run. The error growth is mitigated by averaging all the forecasts in the ensemble, and the variability in the ensemble of forecasts quantifies the uncertainty of the weather conditions. While effective, generating these probabilistic forecasts is computationally costly. They require running highly complex numerical weather models on massive supercomputers multiple times.
Consequently, many operational weather forecasts can only afford to generate ~10–50 ensemble members for each forecast cycle. This is a problem for users concerned with the likelihood of rare but high-impact weather events, which typically require much larger ensembles to assess beyond a few days. For instance, one would need a 10,000-member ensemble to forecast the likelihood of events with 1% probability of occurrence with a relative error less than 10%. Quantifying the probability of such extreme events could be useful, for example, for emergency management preparation or for energy traders.

SEEDS: AI-enabled advances

In our paper, we present the Scalable Ensemble Envelope Diffusion Sampler (SEEDS), a generative AI technology for weather forecast ensemble generation. SEEDS is based on denoising diffusion probabilistic models, a state-of-the-art generative AI method pioneered in part by Google Research. SEEDS can generate a large ensemble conditioned on as few as one or two forecasts from an operational numerical weather prediction system. The generated ensembles not only yield plausible real-weather–like forecasts but also match or exceed physics-based ensembles in skill metrics such as the rank histogram, the root-mean-squared error (RMSE), and the continuous ranked probability score (CRPS). In particular, the generated ensembles assign more accurate likelihoods to the tail of the forecast distribution, such as ±2σ and ±3σ weather events. Most importantly, the computational cost of the model is negligible when compared to the hours of computational time needed by supercomputers to make a forecast. It has a throughput of 256 ensemble members (at 2° resolution) every 3 minutes on Google Cloud TPUv3-32 instances and can easily scale to higher throughput by deploying more accelerators.

SEEDS generates an order-of-magnitude more samples to in-fill distributions of weather patterns.
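The ensemble-size figure quoted above (roughly 10,000 members for a 1% event at 10% relative error) can be checked with a short calculation. The sketch below assumes that ensemble members behave like independent draws, so that the number of members showing the event is binomially distributed; real ensemble members are not perfectly independent, and the function names are illustrative only.

```python
import math


def relative_error(p, n_members):
    """Relative standard error of the estimated event frequency.

    Assumes n_members independent members, each showing the event with
    probability p (a binomial sampling model).
    """
    return math.sqrt((1.0 - p) / (n_members * p))


def members_needed(p, target_relative_error):
    """Smallest ensemble size whose relative standard error meets the target."""
    return math.ceil((1.0 - p) / (p * target_relative_error ** 2))


# How well does an ensemble of a given size pin down a 1%-probability event?
for n in (50, 1_000, 10_000):
    print(n, round(relative_error(0.01, n), 3))

# Ensemble size needed for 10% relative error on a 1%-probability event.
print(members_needed(0.01, 0.10))
```

Under this simple model, a 50-member ensemble has a relative error above 100% for a 1% event, which is why small operational ensembles struggle to characterize rare events.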
Generating plausible weather forecasts

We compare samples generated by SEEDS with forecasts from the operational U.S. weather prediction system (Global Ensemble Forecast System, GEFS) for a particular date during the 2022 European heat waves. We also compare the results to the forecasts from a Gaussian model that predicts the univariate mean and standard deviation of each atmospheric field at each location, a common and computationally efficient but less sophisticated data-driven approach. This Gaussian model is meant to characterize the output of pointwise post-processing, which ignores correlations and treats each grid point as an independent random variable. In contrast, a real weather map would have detailed correlational structures.

Because SEEDS directly models the joint distribution of the atmospheric state, it realistically captures both the spatial covariance and the correlation between mid-tropospheric geopotential and mean sea level pressure, which are closely related and are commonly used by weather forecasters for evaluation and verification of forecasts. Gradients in the mean sea level pressure are what drive winds at the surface, while gradients in mid-tropospheric geopotential create upper-level winds that move large-scale weather patterns.

The generated samples from SEEDS shown in the figure below (frames Ca–Ch) display a geopotential trough west of Portugal with spatial structure similar to that found in the operational U.S. forecasts or the reanalysis based on observations. Although the Gaussian model predicts the marginal univariate distributions adequately, it fails to capture cross-field or spatial correlations. This hinders the assessment of the effects that these anomalies may have on hot air intrusions from North Africa, which can exacerbate heat waves over Europe.

Stamp maps over Europe on 2022/07/14 at 0:00 UTC. The contours are for the mean sea level pressure (dashed lines mark isobars below 1010 hPa) while the heatmap depicts the geopotential height at the 500 hPa pressure level.
(A) The ERA5 reanalysis, a proxy for real observations. (Ba–Bb) 2 members from the 7-day U.S. operational forecasts used as seeds to our model. (Ca–Ch) 8 samples drawn from SEEDS. (Da–Dh) 8 non-seeding members from the 7-day U.S. operational ensemble forecast. (Ea–Ed) 4 samples from a pointwise Gaussian model parameterized by the mean and variance of the entire U.S. operational ensemble.

Covering extreme events more accurately

SEEDS provides better statistical coverage of the 2022/07/14 European extreme heat event, denoted by the brown star. Each plot shows the values of the total column-integrated water vapor (TCWV) vs. temperature over a grid point near Lisbon, Portugal from 16,384 samples generated by our models, shown as green dots, conditioned on 2 seeds (blue squares) taken from the 7-day U.S. operational ensemble forecasts (denoted by the sparser brown triangles). The valid forecast time is 1:00 local time. The solid contour levels correspond to iso-proportions of the kernel density of SEEDS, with the outermost one encircling 95% of the mass and 11.875% between each level.

Conclusion and future outlook

Acknowledgements

All SEEDS authors, Lizao Li, Rob Carver, Ignacio Lopez-Gomez, Fei Sha and John Anderson, co-authored this blog post, with Carla Bromberg as Program Lead. We also thank Tom Small who designed the animation. Our colleagues at Google Research have provided invaluable advice to the SEEDS work. Among them, we thank Leonardo Zepeda-Núñez, Zhong Yi Wan, Stephan Rasp, Stephan Hoyer, and Tapio Schneider for their inputs and useful discussion. We thank Tyler Russell for additional technical program management, as well as Alex Merose for data coordination and support. We also thank Cenk Gazen, Shreya Agrawal, and Jason Hickey for discussions in the early stage of the SEEDS work.
Posted by Urs Köster, Software Engineer, Google Research

Time series problems are ubiquitous, from forecasting weather and traffic patterns to understanding economic trends. Bayesian approaches start with an assumption about the data's patterns (prior probability), collect evidence (e.g., new time series data), and continuously update that assumption to form a posterior probability distribution. Traditional Bayesian approaches like Gaussian processes (GPs) and Structural Time Series are extensively used for modeling time series data, e.g., the commonly used Mauna Loa CO2 dataset. However, they often rely on domain experts to painstakingly select appropriate model components and may be computationally expensive. Alternatives such as neural networks lack interpretability, making it difficult to understand how they generate forecasts, and don't produce reliable confidence intervals.

To that end, we introduce AutoBNN, a new open-source package written in JAX. AutoBNN automates the discovery of interpretable time series forecasting models, provides high-quality uncertainty estimates, and scales effectively for use on large datasets. We describe how AutoBNN combines the interpretability of traditional probabilistic approaches with the scalability and flexibility of neural networks.

AutoBNN

AutoBNN is based on a line of research that over the past decade has yielded improved predictive accuracy by modeling time series using GPs with learned kernel structures. The kernel function of a GP encodes assumptions about the function being modeled, such as the presence of trends, periodicity or noise. With learned GP kernels, the kernel function is defined compositionally: it is either a base kernel (such as Linear, Quadratic, Periodic, Matérn or ExponentiatedQuadratic) or a composite that combines two or more kernel functions using operators such as Addition, Multiplication, or ChangePoint. This compositional kernel structure serves two related purposes.
First, it is simple enough that a user who is an expert about their data, but not necessarily about GPs, can construct a reasonable prior for their time series. Second, techniques like Sequential Monte Carlo can be used for discrete searches over small structures and can output interpretable results.

AutoBNN replaces the GP with Bayesian neural networks (BNNs) while retaining the compositional kernel structure. A BNN is a neural network with a probability distribution over weights rather than a fixed set of weights. This induces a distribution over outputs, capturing uncertainty in the predictions. BNNs bring the following advantages over GPs: First, training large GPs is computationally expensive, and traditional training algorithms scale as the cube of the number of data points in the time series. In contrast, for a fixed width, training a BNN will often be approximately linear in the number of data points. Second, BNNs lend themselves better to GPU and TPU hardware acceleration than GP training operations. Third, compositional BNNs can be easily combined with traditional deep BNNs, which have the ability to do feature discovery. One could imagine "hybrid" architectures, in which users specify a top-level structure of Add(Linear, Periodic, Deep), and the deep BNN is left to learn the contributions from potentially high-dimensional covariate information.

How might one translate a GP with compositional kernels into a BNN then? A single layer neural network will typically converge to a GP as the number of neurons (or "width") goes to infinity. More recently, researchers have discovered a correspondence in the other direction: many popular GP kernels (such as Matern, ExponentiatedQuadratic, Polynomial or Periodic) can be obtained as infinite-width BNNs with appropriately chosen activation functions and weight distributions. Furthermore, these BNNs remain close to the corresponding GP even when the width is far from infinite.
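As a rough illustration of this kernel/network correspondence, the sketch below uses random Fourier features, a standard construction in which a single hidden layer with cosine activations and Gaussian-distributed input weights reproduces the ExponentiatedQuadratic (RBF) kernel as the width grows. This is an assumption chosen for illustration and is not necessarily the construction AutoBNN uses; all function names are hypothetical.

```python
import math
import random


def rbf_kernel(x, y, length_scale=1.0):
    """Exact ExponentiatedQuadratic (RBF) kernel for scalar inputs."""
    return math.exp(-((x - y) ** 2) / (2.0 * length_scale ** 2))


def random_cosine_features(xs, width, length_scale=1.0, seed=0):
    """Hidden-layer activations of a width-`width` network with cos units.

    Input weights are drawn from N(0, 1/length_scale^2) and biases from
    U(0, 2*pi). With i.i.d. N(0, 1) output weights, the covariance of the
    network output at x and y is the dot product of these feature vectors.
    """
    rng = random.Random(seed)
    ws = [rng.gauss(0.0, 1.0 / length_scale) for _ in range(width)]
    bs = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(width)]
    scale = math.sqrt(2.0 / width)
    return [[scale * math.cos(w * x + b) for w, b in zip(ws, bs)] for x in xs]


def max_gram_error(xs, width, seed=0):
    """Largest absolute gap between the network Gram matrix and the exact one."""
    feats = random_cosine_features(xs, width, seed=seed)
    worst = 0.0
    for i, xi in enumerate(xs):
        for j, xj in enumerate(xs):
            approx = sum(a * b for a, b in zip(feats[i], feats[j]))
            worst = max(worst, abs(approx - rbf_kernel(xi, xj)))
    return worst


xs = [0.5 * i for i in range(10)]
for width in (10, 100, 4000):
    print(width, round(max_gram_error(xs, width), 3))
```

The printed gap shrinks roughly like one over the square root of the width, which gives a feel for how quickly a finite network approaches its limiting kernel in this particular construction.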
For example, the figures below show the difference in the covariance between pairs of observations, and regression results of the true GPs and their corresponding width-10 neural network versions.

Comparison of Gram matrices between true GP kernels (top row) and their width 10 neural network approximations (bottom row).

Comparison of regression results between true GP kernels (top row) and their width 10 neural network approximations (bottom row).

The translation is completed with BNN analogues of the Addition and Multiplication operators over GPs, and input warping to produce periodic kernels. BNN addition is straightforwardly given by adding the outputs of the component BNNs. BNN multiplication is achieved by multiplying the activations of the hidden layers of the BNNs and then applying a shared dense layer. We are therefore limited to only multiplying BNNs with the same hidden width.

Using AutoBNN

The AutoBNN package is available within TensorFlow Probability. It is implemented in JAX and uses the flax.linen neural network library. It implements all of the base kernels and operators discussed so far (Linear, Quadratic, Matern, ExponentiatedQuadratic, Periodic, Addition, Multiplication) plus one new kernel and three new operators: OneLayer, a kernel that is a single hidden layer ReLU BNN; ChangePoint, an operator that allows smoothly switching between two kernels; LearnableChangePoint, which is the same as ChangePoint except that position and slope are given prior distributions and can be learned from the data; and WeightedSum.

WeightedSum combines two or more BNNs with learnable mixing weights, where the learnable weights follow a Dirichlet prior. By default, a flat Dirichlet distribution with concentration 1.0 is used. WeightedSums allow a "soft" version of structure discovery, i.e., training a linear combination of many possible models at once.
In contrast to structure discovery with discrete structures, such as in AutoGP, this allows us to use standard gradient methods to learn structures, rather than using expensive discrete optimization. Instead of evaluating potential combinatorial structures in series, WeightedSum allows us to evaluate them in parallel.

To easily enable exploration, AutoBNN defines a number of model structures that contain either top-level or internal WeightedSums. The names of these models can be used as the first parameter in any of the estimator constructors, and include things like sum_of_stumps (the WeightedSum over all the base kernels) and sum_of_shallow (which adds all possible combinations of base kernels with all operators).

Illustration of the sum_of_stumps model. The bars in the top row show the amount by which each base kernel contributes, and the bottom row shows the function represented by the base kernel. The resulting weighted sum is shown on the right.

The figure below demonstrates structure discovery on the N374 series from the M3 dataset. The six base structures were ExponentiatedQuadratic (which is the same as the Radial Basis Function kernel, or RBF for short), Matern, Linear, Quadratic, OneLayer and Periodic kernels. The figure shows the MAP estimates of their weights over an ensemble of 32 particles. All of the high likelihood particles gave a large weight to the Periodic component, low weights to Linear, Quadratic and OneLayer, and a large weight to either RBF or Matern.

Parallel coordinates plot of the MAP estimates of the base kernel weights over 32 particles. The sum_of_stumps model was trained on the N374 series from the M3 dataset (inset in blue). Darker lines correspond to particles with higher likelihoods.

By using WeightedSums as the inputs to other operators, it is possible to express rich combinatorial structures, while keeping models compact and the number of learnable weights small.
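To make the WeightedSum idea concrete, here is a toy sketch that mixes a few component functions with weights drawn from a flat Dirichlet distribution (concentration 1.0). The component functions are simple stand-ins for BNN outputs, the weights are a single prior draw rather than learned values, and none of the names below belong to the AutoBNN API.

```python
import math
import random


def sample_flat_dirichlet(n, rng, concentration=1.0):
    """Draw mixing weights from a symmetric Dirichlet via normalized gammas."""
    draws = [rng.gammavariate(concentration, 1.0) for _ in range(n)]
    total = sum(draws)
    return [d / total for d in draws]


# Stand-ins for the outputs of component models (hypothetical, not BNNs).
COMPONENTS = {
    "linear": lambda t: 0.5 * t,
    "quadratic": lambda t: 0.1 * t * t,
    "periodic": lambda t: math.sin(2.0 * math.pi * t / 12.0),
}


def weighted_sum(t, weights):
    """Convex combination of the component outputs at time t."""
    return sum(w * f(t) for w, f in zip(weights, COMPONENTS.values()))


rng = random.Random(0)
weights = sample_flat_dirichlet(len(COMPONENTS), rng)
for name, w in zip(COMPONENTS, weights):
    print(f"{name}: {w:.3f}")
print("mixture at t=3:", round(weighted_sum(3.0, weights), 3))
```

Setting a weight to zero switches the corresponding component off, which is how a single continuous mixture can stand in for many discrete structures.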
As an example, we include the sum_of_products model (illustrated in the figure below) which first creates a pairwise product of two WeightedSums, and then a sum of the two products. By setting some of the weights to zero, we can create many different discrete structures. The total number of possible structures in this model is 2^16 (65,536), since there are 16 base kernels that can be turned on or off. All these structures are explored implicitly by training just this one model.

Illustration of the "sum_of_products" model. Each of the four WeightedSums has the same structure as the "sum_of_stumps" model.

Certain combinations of kernels (e.g., the product of Periodic and either the Matern or ExponentiatedQuadratic) lead to overfitting on many datasets. To prevent this, we have defined model classes like sum_of_safe_shallow that exclude such products when performing structure discovery with WeightedSums.

For training, AutoBNN provides AutoBnnMapEstimator and AutoBnnMCMCEstimator to perform MAP and MCMC inference, respectively. Either estimator can be combined with any of the six likelihood functions, including four based on normal distributions with different noise characteristics for continuous data and two based on the negative binomial distribution for count data.

Result from running AutoBNN on the Mauna Loa CO2 dataset in our example colab. The model captures the trend and seasonal component in the data. Extrapolating into the future, the mean prediction slightly underestimates the actual trend, while the 95% confidence interval gradually widens.
The package follows a scikit-learn–inspired estimator interface:

```python
import autobnn as ab
import jax  # needed for the PRNG key passed to the estimator

# Additive model: periodic + linear + Matern components, each of width 50.
model = ab.operators.Add(
    bnns=(ab.kernels.PeriodicBNN(width=50),
          ab.kernels.LinearBNN(width=50),
          ab.kernels.MaternBNN(width=50)))

# MAP estimator with a normal likelihood; `periods` feeds the periodic kernel.
estimator = ab.estimators.AutoBnnMapEstimator(
    model, 'normal_likelihood_logistic_noise', jax.random.PRNGKey(42),
    periods=[12])

# my_training_data_xs / my_training_data_ys are placeholders for your data.
estimator.fit(my_training_data_xs, my_training_data_ys)
low, mid, high = estimator.predict_quantiles(my_training_data_xs)
```

Conclusion

AutoBNN provides a powerful and flexible framework for building sophisticated time series prediction models. By combining the strengths of BNNs and GPs with compositional kernels, AutoBNN opens a world of possibilities for understanding and forecasting complex data. We invite the community to try the colab, and leverage this library to innovate and solve real-world challenges.

Acknowledgements

AutoBNN was written by Colin Carroll, Thomas Colthurst, Urs Köster and Srinivas Vasudevan. We would like to thank Kevin Murphy, Brian Patton and Feras Saad for their advice and feedback.
Posted by Atilla Kiraly, Software Engineer, and Rory Pilgrim, Product Manager, Google Research

Lung cancer is the leading cause of cancer-related deaths globally, with 1.8 million deaths reported in 2020. Late diagnosis dramatically reduces the chances of survival. Lung cancer screening via computed tomography (CT), which provides a detailed 3D image of the lungs, has been shown to reduce mortality in high-risk populations by at least 20% by detecting potential signs of cancers earlier. In the US, screening involves annual scans, with some countries or cases recommending more or less frequent scans.

The United States Preventive Services Task Force recently expanded lung cancer screening recommendations by roughly 80%, which is expected to increase screening access for women and racial and ethnic minority groups. However, false positives (i.e., incorrectly reporting a potential cancer in a cancer-free patient) can cause anxiety and lead to unnecessary procedures for patients while increasing costs for the healthcare system. Moreover, efficiency in screening a large number of individuals can be challenging depending on healthcare infrastructure and radiologist availability.

At Google we have previously developed machine learning (ML) models for lung cancer detection, and have evaluated their ability to automatically detect and classify regions that show signs of potential cancer. Performance has been shown to be comparable to that of specialists in detecting possible cancer. While these models have achieved high performance, effectively communicating findings in realistic environments is necessary to realize their full potential.

To that end, in “Assistive AI in Lung Cancer Screening: A Retrospective Multinational Study in the US and Japan”, published in Radiology AI, we investigate how ML models can effectively communicate findings to radiologists. We also introduce a generalizable user-centric interface to help radiologists leverage such models for lung cancer screening.
The system takes CT imaging as input and outputs a cancer suspicion rating using four categories (no suspicion, probably benign, suspicious, highly suspicious) along with the corresponding regions of interest. We evaluate the system’s utility in improving clinician performance through randomized reader studies in both the US and Japan, using the local cancer scoring systems (Lung-RADS v1.1 and Sendai Score) and image viewers that mimic realistic settings. We found that reader specificity increases with model assistance in both reader studies. To accelerate progress in conducting similar studies with ML models, we have open-sourced code to process CT images and generate images compatible with the picture archiving and communication system (PACS) used by radiologists.

Developing an interface to communicate model results

Screening guidelines such as Lung-RADS assign an alpha-numeric score to indicate the lung cancer risk and follow-up recommendations. When assessing patients, radiologists load the CT in their workstation to read the case, find lung nodules or lesions, and apply set guidelines to determine follow-up decisions.

Our first step was to improve the previously developed ML models through additional training data and architectural improvements, including self-attention. Then, instead of targeting specific guidelines, we experimented with a complementary way of communicating AI results independent of guidelines or their particular versions. Specifically, the system output offers a suspicion rating and localization (regions of interest) for the user to consider in conjunction with their own specific guidelines. The interface produces output images directly associated with the CT study, requiring no changes to the user’s workstation. The radiologist only needs to review a small set of additional images. There is no other change to their system or interaction with the system.

Example of the assistive lung cancer screening system outputs.
Results for the radiologist’s evaluation are visualized on the location of the CT volume where the suspicious lesion is found. The overall suspicion is displayed at the top of the CT images. Circles highlight the suspicious lesions while squares show a rendering of the same lesion from a different perspective, called a sagittal view.

The system builds on prior work. The models coordinate with each other to first segment the lungs, obtain an overall assessment, locate three suspicious regions, then use the information to assign a suspicion rating to each region. The system was deployed on Google Cloud using Google Kubernetes Engine (GKE), which pulled the images, ran the ML models, and provided results. This allows scalability and directly connects to servers where the images are stored in DICOM stores.

Outline of the Google Cloud deployment of the assistive lung cancer screening system and the directional calling flow for the individual components that serve the images and compute results. Images are served to the viewer and to the system using Google Cloud services. The system runs on Google Kubernetes Engine, which pulls the images, processes them, and writes them back into the DICOM store.

Reader studies

Reader performance was measured with metrics including area under the ROC curve (AUC) values. These were compared with and without assistance.

A multi-case multi-reader study involves each case being reviewed by each reader twice, once with ML system assistance and once without. In this visualization one reader first reviews Set A without assistance (blue) and then with assistance (orange) after a wash-out period. A second reader group follows the opposite path by reading the same set of cases (Set A) with assistance first. Readers are randomized to these groups to remove the effect of ordering.

With the assistive system, readers improved their ability to correctly identify cases without actionable lung cancer findings (i.e., specificity) by an absolute 5–7% compared to when they didn’t use the assistive system.
This potentially means that for every 15–20 patients screened, one may be able to avoid unnecessary follow-up procedures, thus reducing their anxiety and the burden on the health care system. This can, in turn, help improve the sustainability of lung cancer screening programs, particularly as more people become eligible for screening.

Reader specificity increases with ML model assistance in both the US-based and Japan-based reader studies. Specificity values were derived from reader scores from actionable findings (something suspicious was found) versus no actionable findings, compared against the true cancer outcome of the individual. Under model assistance, readers flagged fewer cancer-negative individuals for follow-up visits. Sensitivity for cancer positive individuals remained the same.

Translating this into real-world impact through partnership

We are working with DeepHealth, a leading AI-powered health informatics provider, and Apollo Radiology International, a leading provider of radiology services in India, to explore paths for incorporating this system into future products. In addition, we are looking to help other researchers studying how best to integrate ML model results into clinical workflows by open sourcing code used for the reader study and incorporating the insights described in this blog. We hope that this will help accelerate medical imaging researchers looking to conduct reader studies for their AI models, and catalyze translational research in the field.

Acknowledgements

Key contributors to this project include Corbin Cunningham, Zaid Nabulsi, Ryan Najafi, Jie Yang, Charles Lau, Joseph R. Ledsam, Wenxing Ye, Diego Ardila, Scott M. McKinney, Rory Pilgrim, Hiroaki Saito, Yasuteru Shimamura, Mozziyar Etemadi, Yun Liu, David Melnick, Sunny Jansen, Nadia Harhen, David P. Nadich, Mikhail Fomitchev, Ziyad Helali, Shabir Adeel, Greg S. Corrado, Lily Peng, Daniel Tse, Shravya Shetty, Shruthi Prabhakara, Neeral Beladia, and Krish Eswaran.
Thanks to Arnav Agharwal and Andrew Sellergren for their open sourcing support and Vivek Natarajan and Michael D. Howell for their feedback. Sincere appreciation also goes to the radiologists who enabled this work with their image interpretation and annotation efforts throughout the study, and Jonny Wong and Carli Sampson for coordinating the reader studies.
Posted by Yossi Matias, VP Engineering & Research, and Grey Nearing, Research Scientist, Google Research

Floods are the most common natural disaster, and are responsible for roughly $50 billion in annual financial damages worldwide. The rate of flood-related disasters has more than doubled since the year 2000, partly due to climate change. Nearly 1.5 billion people, making up 19% of the world’s population, are exposed to substantial risks from severe flood events. Upgrading early warning systems to make accurate and timely information accessible to these populations can save thousands of lives per year.

Driven by the potential impact of reliable flood forecasting on people’s lives globally, we started our flood forecasting effort in 2017. Through this multi-year journey, we advanced research hand-in-hand with building a real-time operational flood forecasting system that provides alerts on Google Search, Maps, Android notifications and through the Flood Hub. However, in order to scale globally, especially in places where accurate local data is not available, more research advances were required.

In “Global prediction of extreme floods in ungauged watersheds”, published in Nature, we demonstrate how machine learning (ML) technologies can significantly improve global-scale flood forecasting relative to the current state-of-the-art for countries where flood-related data is scarce. With these AI-based technologies we extended the reliability of currently-available global nowcasts, on average, from zero to five days, and improved forecasts across regions in Africa and Asia to be similar to what is currently available in Europe. The evaluation of the models was conducted in collaboration with the European Centre for Medium-Range Weather Forecasts (ECMWF).

These technologies also enable Flood Hub to provide real-time river forecasts up to seven days in advance, covering river reaches across over 80 countries.
This information can be used by people, communities, governments and international organizations to take anticipatory action to help protect vulnerable populations.

Flood forecasting at Google

We launched a pilot early warning system in the Ganges-Brahmaputra river basin in India, with the hypothesis that ML could help address the challenging problem of reliable flood forecasting at scale. The pilot was further expanded the following year via the combination of an inundation model, real-time water level measurements, the creation of an elevation map and hydrologic modeling.

In collaboration with academics, and, in particular, with the JKU Institute for Machine Learning, we explored ML-based hydrologic models, showing that LSTM-based models could produce more accurate simulations than traditional conceptual and physics-based hydrology models. This research led to flood forecasting improvements that enabled the expansion of our forecasting coverage to include all of India and Bangladesh. We also worked with researchers at Yale University to test technological interventions that increase the reach and impact of flood warnings.

Our hydrological models predict river floods by processing publicly available weather data like precipitation and physical watershed information. Such models must be calibrated to long data records from streamflow gauging stations in individual rivers. A low percentage of global river watersheds (basins) have streamflow gauges, which are expensive but necessary to supply relevant data, and it’s challenging for hydrological simulation and forecasting to provide predictions in basins that lack this infrastructure. Lower gross domestic product (GDP) is correlated with increased vulnerability to flood risks, and there is an inverse correlation between national GDP and the amount of publicly available data in a country.
ML helps to address this problem by allowing a single model to be trained on all available river data and to be applied to ungauged basins where no data are available. In this way, models can be trained globally, and can make predictions for any river location.

There is an inverse (log-log) correlation between the amount of publicly available streamflow data in a country and national GDP. Streamflow data from the Global Runoff Data Center.

Further research developed methods to estimate uncertainty in river forecasts and showed how ML river forecast models synthesize information from multiple data sources. This work demonstrated that these models can simulate extreme events reliably, even when those events are not part of the training data. In an effort to contribute to open science, in 2023 we open-sourced a community-driven dataset for large-sample hydrology in Nature Scientific Data.

The river forecast model

LSTMs perform well on the task of river forecasting.

A diagram of the LSTM, which is a neural network that operates sequentially in time.

The model uses mixture density networks to produce a probabilistic forecast (i.e., predicted parameters of a probability distribution over streamflow). Specifically, the model predicts the parameters of a mixture of heavy-tailed probability density functions, called asymmetric Laplacian distributions, at each forecast time step. The result is a mixture density function, called a Countable Mixture of Asymmetric Laplacians (CMAL) distribution, which represents a probabilistic prediction of the volumetric flow rate in a particular river at a particular time.

LSTM-based river forecast model architecture. Two LSTMs are applied in sequence, one ingesting historical weather data and one ingesting forecasted weather data. The model outputs are the parameters of a probability distribution over streamflow at each forecasted timestep.
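To give a feel for what a mixture of asymmetric Laplacians looks like as a forecast distribution, the sketch below evaluates the density and an exceedance probability for one time step. It assumes the common quantile-regression parameterization of the asymmetric Laplace distribution (location, scale and an asymmetry parameter tau between 0 and 1); the exact parameterization in the model may differ, and the mixture parameters here are made-up values, not model output.

```python
import math


def asymmetric_laplace_pdf(x, loc, scale, tau):
    """Density of an asymmetric Laplace distribution (assumed parameterization)."""
    u = (x - loc) / scale
    check = u * (tau - (1.0 if u < 0 else 0.0))  # "check" loss, always >= 0
    return tau * (1.0 - tau) / scale * math.exp(-check)


def asymmetric_laplace_cdf(x, loc, scale, tau):
    """Closed-form CDF matching asymmetric_laplace_pdf."""
    u = (x - loc) / scale
    if u < 0:
        return tau * math.exp((1.0 - tau) * u)
    return 1.0 - (1.0 - tau) * math.exp(-tau * u)


def cmal_pdf(x, weights, locs, scales, taus):
    """Mixture density: weighted sum of asymmetric Laplace components."""
    return sum(w * asymmetric_laplace_pdf(x, m, b, t)
               for w, m, b, t in zip(weights, locs, scales, taus))


def exceedance_probability(threshold, weights, locs, scales, taus):
    """Probability that streamflow exceeds a threshold under the mixture."""
    below = sum(w * asymmetric_laplace_cdf(threshold, m, b, t)
                for w, m, b, t in zip(weights, locs, scales, taus))
    return 1.0 - below


# Hypothetical mixture parameters for a single forecast time step.
WEIGHTS = [0.6, 0.3, 0.1]
LOCS = [20.0, 35.0, 60.0]
SCALES = [2.0, 4.0, 6.0]
TAUS = [0.5, 0.4, 0.3]

print("density at 25:", cmal_pdf(25.0, WEIGHTS, LOCS, SCALES, TAUS))
print("P(flow > 60):", exceedance_probability(60.0, WEIGHTS, LOCS, SCALES, TAUS))
```

A threshold such as a return-period flow can be plugged into `exceedance_probability` to turn the predicted distribution into a flood probability for that time step.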
Input and training data Static watershed attributes representing geographical and geophysical variables: From the HydroATLAS project, including data like long-term climate indexes (precipitation, temperature, snow fractions), land cover, and anthropogenic attributes (e.g., a nighttime lights index as a proxy for human development). Historical meteorological time-series data: Used to spin up the model for one year prior to the issue time of a forecast. The data come from NASA IMERG, NOAA CPC Global Unified Gauge-Based Analysis of Daily Precipitation, and the ECMWF ERA5-land reanalysis. Variables include daily total precipitation, air temperature, solar and thermal radiation, snowfall, and surface pressure. Forecasted meteorological time series over a seven-day forecast horizon: Used as input for the forecast LSTM. These data are the same meteorological variables listed above, and come from the ECMWF HRES atmospheric model. Training data are daily streamflow values from the Global Runoff Data Center over the time period 1980–2023. A single streamflow forecast model is trained using data from 5,680 diverse watershed streamflow gauges (shown below) to improve accuracy. Location of 5,680 streamflow gauges that supply training data for the river forecast model from the Global Runoff Data Center. Improving on the current state-of-the-art GloFAS version 4, the current state-of-the-art global flood forecasting system. These experiments showed that ML can provide accurate warnings earlier and over larger and more impactful events. The figure below shows the distribution of F1 scores when predicting events of different severity at river locations around the world, where a prediction counts as correct if it falls within plus or minus 1 day of the observed event. The F1 score is the harmonic mean of precision and recall, and event severity is measured by return period. For example, a 2-year return period event is a volume of streamflow that is expected to be exceeded on average once every two years. 
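As a rough illustration of the evaluation described above, the sketch below computes an F1 score (the harmonic mean of precision and recall) in which a predicted event counts as a hit if it falls within one day of an observed event, and converts a return period into an annual exceedance probability. The matching rule and the day indices are assumptions for illustration only; the published evaluation may match events differently.

```python
def f1_with_tolerance(observed_days, predicted_days, tolerance_days=1):
    """F1 score for event detection with a +/- tolerance in days.

    Assumed matching rule: a predicted event is a hit if any observed event
    lies within `tolerance_days`, and vice versa for recall.
    """
    observed = sorted(set(observed_days))
    predicted = sorted(set(predicted_days))
    if not observed or not predicted:
        return 0.0
    matched_predictions = sum(
        any(abs(p - o) <= tolerance_days for o in observed) for p in predicted
    )
    matched_observations = sum(
        any(abs(o - p) <= tolerance_days for p in predicted) for o in observed
    )
    precision = matched_predictions / len(predicted)
    recall = matched_observations / len(observed)
    if precision + recall == 0.0:
        return 0.0
    # Harmonic mean of precision and recall.
    return 2.0 * precision * recall / (precision + recall)


def annual_exceedance_probability(return_period_years):
    """A T-year event is exceeded on average once every T years."""
    return 1.0 / return_period_years


# Hypothetical day-of-record indices of observed and predicted flood events.
observed_events = [10, 40, 95]
predicted_events = [11, 60, 94]

score = f1_with_tolerance(observed_events, predicted_events)
print(score)
print(annual_exceedance_probability(2))
```

Here two of the three predictions land within a day of an observed event and two of the three observed events are caught, so precision and recall are both 2/3.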
Our model achieves reliability scores at up to 4-day or 5-day lead times that are similar to or better, on average, than the reliability of GloFAS nowcasts (0-day lead time). Distributions of F1 scores over 2-year return period events in 2,092 watersheds globally during the time period 2014-2023 from GloFAS (blue) and our model (orange) at different lead times. On average, our model is statistically as accurate as GloFAS nowcasts (0–day lead time) up to 5 days in advance over 2-year (shown) and 1-year, 5-year, and 10-year events (not shown). paper for more information. Looking into the future Adaptation and Resilience efforts and reflects Google's commitment to address climate change while helping global communities become more resilient. We believe that AI and ML will continue to play a critical role in helping advance science and research towards climate action. We actively collaborate with several international aid organizations (e.g., the Centre for Humanitarian Data and the Red Cross) to provide actionable flood forecasts. Additionally, in an ongoing collaboration with the World Meteorological Organization (WMO) to support early warning systems for climate hazards, we are conducting a study to help understand how AI can help address real-world challenges faced by national flood forecasting agencies. While the work presented here demonstrates a significant step forward in flood forecasting, future work is needed to further expand flood forecasting coverage to more locations globally and other types of flood-related events and disasters, including flash floods and urban floods. We are looking forward to continuing collaborations with our partners in the academic and expert communities, local governments and the industry to reach these goals.
Posted by Srinivas Sunkara and Gilles Baechler, Software Engineers, Google Research Screen user interfaces (UIs) and infographics, such as charts, diagrams and tables, play important roles in human communication and human-machine interaction as they facilitate rich and interactive user experiences. UIs and infographics share similar design principles and visual language (e.g., icons and layouts), that offer an opportunity to build a single model that can understand, reason, and interact with these interfaces. However, because of their complexity and varied presentation formats, infographics and UIs present a unique modeling challenge. To that end, we introduce “ScreenAI: A Vision-Language Model for UI and Infographics Understanding”. ScreenAI improves upon the PaLI architecture with the flexible patching strategy from pix2struct. We train ScreenAI on a unique mixture of datasets and tasks, including a novel Screen Annotation task that requires the model to identify UI element information (i.e., type, location and description) on a screen. These text annotations provide large language models (LLMs) with screen descriptions, enabling them to automatically generate question-answering (QA), UI navigation, and summarization training datasets at scale. At only 5B parameters, ScreenAI achieves state-of-the-art results on UI- and infographic-based tasks (WebSRC and MoTIF), and best-in-class performance on Chart QA, DocVQA, and InfographicVQA compared to models of similar size. We are also releasing three new datasets: Screen Annotation to evaluate the layout understanding capability of the model, as well as ScreenQA Short and Complex ScreenQA for a more comprehensive evaluation of its QA capability. ScreenAI PaLI, composed of a multimodal encoder block and an autoregressive decoder. The PaLI encoder uses a vision transformer (ViT) that creates image embeddings and a multimodal encoder that takes the concatenation of the image and text embeddings as input. 
This flexible architecture allows ScreenAI to solve vision tasks that can be recast as text+image-to-text problems. On top of the PaLI architecture, we employ a flexible patching strategy introduced in pix2struct. Instead of using a fixed-grid pattern, the grid dimensions are selected such that they preserve the native aspect ratio of the input image. This enables ScreenAI to work well across images of various aspect ratios. The ScreenAI model is trained in two stages: a pre-training stage followed by a fine-tuning stage. First, self-supervised learning is applied to automatically generate data labels, which are then used to train ViT and the language model. ViT is frozen during the fine-tuning stage, where most data used is manually labeled by human raters. ScreenAI model architecture. Data generation publicly accessible web pages and following the programmatic exploration approach used for the RICO dataset for mobile apps. We then apply a layout annotator, based on the DETR model, that identifies and labels a wide range of UI elements (e.g., image, pictogram, button, text) and their spatial relationships. Pictograms undergo further analysis using an icon classifier capable of distinguishing 77 different icon types. This detailed classification is essential for interpreting the subtle information conveyed through icons. For icons that are not covered by the classifier, and for infographics and images, we use the PaLI image captioning model to generate descriptive captions that provide contextual information. We also apply an optical character recognition (OCR) engine to extract and annotate textual content on screen. We combine the OCR text with the previous annotations to create a detailed description of each screen. A mobile app screenshot with generated annotations that include UI elements and their descriptions, e.g., TEXT elements also contain the text content from OCR, IMAGE elements contain image captions, LIST_ITEMs contain all their child elements. 
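The flexible patching strategy mentioned earlier (choosing grid dimensions that preserve the input's native aspect ratio) can be sketched as follows. This follows the general idea of picking a rows-by-columns grid under a fixed patch budget; the function name, `patch_size` and `max_patches` values are illustrative assumptions, not ScreenAI's actual configuration.

```python
import math


def flexible_patch_grid(image_height, image_width, patch_size=16, max_patches=1024):
    """Pick a patch grid (rows, cols) that tracks the image aspect ratio.

    The grid is chosen so that rows * cols stays within `max_patches`
    while rows / cols stays close to image_height / image_width.
    """
    # Scale factor that would make the resized image fill the patch budget.
    scale = math.sqrt(
        max_patches * (patch_size / image_height) * (patch_size / image_width)
    )
    rows = max(min(math.floor(scale * image_height / patch_size), max_patches), 1)
    cols = max(min(math.floor(scale * image_width / patch_size), max_patches), 1)
    return rows, cols


# A portrait phone screenshot, a landscape desktop screenshot, a square image.
portrait = flexible_patch_grid(1920, 1080)
landscape = flexible_patch_grid(1080, 1920)
square = flexible_patch_grid(224, 224)
print(portrait, landscape, square)
```

With a fixed square grid, the portrait screenshot would be squashed horizontally; here it instead receives more rows than columns, so each patch covers a roughly square region of the original screen.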
LLM-based data generation PaLM 2 to generate input-output pairs in a two-step process. First, screen annotations are generated using the technique outlined above, then we craft a prompt around this schema for the LLM to create synthetic data. This process requires prompt engineering and iterative refinement to find an effective prompt. We assess the generated data's quality through human validation against a quality threshold. You only speak JSON. Do not write text that isn’t JSON. You are given the following mobile screenshot, described in words. Can you generate 5 questions regarding the content of the screenshot as well as the corresponding short answers to them? The answer should be as short as possible, containing only the necessary information. Your answer should be structured as follows: questions: [ {{question: the question, answer: the answer }}, ... ] {THE SCREEN SCHEMA} A sample prompt for QA data generation. Question answering: The model is asked to answer questions regarding the content of the screenshots, e.g., “When does the restaurant open?” Screen navigation: The model is asked to convert a natural language utterance into an executable action on a screen, e.g., “Click the search button.” Screen summarization: The model is asked to summarize the screen content in one or two sentences. Block diagram of our workflow for generating data for QA, summarization and navigation tasks using existing ScreenAI models and LLMs. Each task uses a custom prompt to emphasize desired aspects, like questions related to counting, involving reasoning, etc. LLM-generated data. Examples for screen QA, navigation and summarization. For navigation, the action bounding box is displayed in red on the screenshot. Experiments and results ChartQA, DocVQA, Multi page DocVQA, InfographicVQA, OCR VQA, Web SRC and ScreenQA. For navigation, datasets used include Referring Expressions, MoTIF, Mug, and Android in the Wild. 
Finally, we use Screen2Words for screen summarization and Widget Captioning for describing specific UI elements. Along with the fine-tuning datasets, we evaluate the fine-tuned ScreenAI model using three novel benchmarks: Screen Annotation: Enables evaluation of the model's layout annotation and spatial understanding capabilities. ScreenQA Short: A variation of ScreenQA in which the ground-truth answers have been shortened to contain only the relevant information, which better aligns with other QA tasks. Complex ScreenQA: Complements ScreenQA Short with more difficult questions (counting, arithmetic, comparison, and non-answerable questions) and contains screens with various aspect ratios. The fine-tuned ScreenAI model achieves state-of-the-art results on various UI and infographic-based tasks (WebSRC and MoTIF) and best-in-class performance on Chart QA, DocVQA, and InfographicVQA compared to models of similar size. ScreenAI achieves competitive performance on Screen2Words and OCR-VQA. Additionally, we report results on the newly introduced benchmark datasets to serve as a baseline for further research. Comparing model performance of ScreenAI with state-of-the-art (SOTA) models of similar size. Model performance increases with size, and the performance has not saturated even at the largest size of 5B parameters. Conclusion Acknowledgements This project is the result of joint work with Maria Wang, Fedir Zubach, Hassan Mansoor, Vincent Etter, Victor Carbune, Jason Lin, Jindong Chen and Abhanshu Sharma. We thank Fangyu Liu, Xi Chen, Efi Kokiopoulou, Jesse Berent, Gabriel Barcik, Lukas Zilka, Oriana Riva, Gang Li, Yang Li, Radu Soricut, and Tania Bedrax-Weiss for their insightful feedback and discussions, along with Rahul Aralikatte, Hao Cheng and Daniel Kim for their support in data preparation. We also thank Jay Yagnik, Blaise Aguera y Arcas, Ewa Dominowska, David Petrou, and Matt Sharifi for their leadership, vision and support. 
We are very grateful to Tom Small for helping us create the animation in this post.
Posted by Pooja Rao, Research Scientist, Google Research Health datasets play a crucial role in research and medical education, but it can be challenging to create a dataset that represents the real world. For example, dermatology conditions are diverse in their appearance and severity and manifest differently across skin tones. Yet, existing dermatology image datasets often lack representation of everyday conditions (like rashes, allergies and infections) and skew towards lighter skin tones. Furthermore, race and ethnicity information is frequently missing, hindering our ability to assess disparities or create solutions. To address these limitations, we are releasing the Skin Condition Image Network (SCIN) dataset in collaboration with physicians at Stanford Medicine. We designed SCIN to reflect the broad range of concerns that people search for online, supplementing the types of conditions typically found in clinical datasets. It contains images across various skin tones and body parts, helping to ensure that future AI tools work effectively for all. We've made the SCIN dataset freely available as an open-access resource for researchers, educators, and developers, and have taken careful steps to protect contributor privacy. Example set of images and metadata from the SCIN dataset. Dataset composition tanning propensity (self-reported Fitzpatrick Skin Type, i.e., sFST), and to describe the texture, duration and symptoms related to their concern. One to three dermatologists labeled each contribution with up to five dermatology conditions, along with a confidence score for each label. The SCIN dataset contains these individual labels, as well as an aggregated and weighted differential diagnosis derived from them that could be useful for model testing or training. These labels were assigned retrospectively and are not equivalent to a clinical diagnosis, but they allow us to compare the distribution of dermatology conditions in the SCIN dataset with existing datasets. 
The SCIN dataset contains largely allergic, inflammatory and infectious conditions while datasets from clinical sources focus on benign and malignant neoplasms. Monk Skin Tone (eMST) for the images. This allowed comparison of the skin condition and skin type distributions to those in existing dermatology datasets. Although we did not selectively target any skin types or skin tones, the SCIN dataset has a balanced Fitzpatrick skin type distribution (with more of Types 3, 4, 5, and 6) compared to similar datasets from clinical sources. Self-reported and dermatologist-estimated Fitzpatrick Skin Type distribution in the SCIN dataset compared with existing un-enriched dermatology datasets (Fitzpatrick17k, PH², SKINL2, and PAD-UFES-20). Fitzpatrick Skin Type scale was originally developed as a photo-typing scale to measure the response of skin types to UV radiation, and it is widely used in dermatology research. The Monk Skin Tone scale is a newer 10-shade scale that measures skin tone rather than skin phototype, capturing more nuanced differences between the darker skin tones. While neither scale was intended for retrospective estimation using images, the inclusion of these labels is intended to enable future research into skin type and tone representation in dermatology. For example, the SCIN dataset provides an initial benchmark for the distribution of these skin types and tones in the US population. The SCIN dataset has a high representation of women and younger individuals, likely reflecting a combination of factors. These could include differences in skin condition incidence, propensity to seek health information online, and variations in willingness to contribute to research across demographics. Crowdsourcing method research paper co-authored with investigators at Stanford Medicine. This approach empowers individuals to play an active role in healthcare research. 
It allows us to reach people at earlier stages of their health concerns, potentially before they seek formal care. Crucially, this method uses advertisements on web search result pages — the starting point for many people’s health journey — to connect with participants. Our results demonstrate that crowdsourcing can yield a high-quality dataset with a low spam rate. Over 97.5% of contributions were genuine images of skin conditions. After performing further filtering steps to exclude images that were out of scope for the SCIN dataset and to remove duplicates, we were able to release nearly 90% of the contributions received over the 8-month study period. Most images were sharp and well-exposed. Approximately half of the contributions include self-reported demographics, and 80% contain self-reported information relating to the skin condition, such as texture, duration, or other symptoms. We found that dermatologists’ ability to retrospectively assign a differential diagnosis depended more on the availability of self-reported information than on image quality. Dermatologist confidence in their labels (scale from 1-5) depended on the availability of self-reported demographic and symptom information. Data Use License prohibits attempts to re-identify contributors. We hope the SCIN dataset will be a helpful resource for those working to advance inclusive dermatology research, education, and AI tool development. By demonstrating an alternative to traditional dataset creation methods, SCIN paves the way for more representative datasets in areas where self-reported data or retrospective labeling is feasible. Acknowledgements We are grateful to all our co-authors Abbi Ward, Jimmy Li, Julie Wang, Sriram Lakshminarasimhan, Ashley Carrick, Bilson Campana, Jay Hartford, Pradeep Kumar S, Tiya Tiyasirisokchai, Sunny Virmani, Renee Wong, Yossi Matias, Greg S. Corrado, Dale R. 
Webster, Dawn Siegel (Stanford Medicine), Steven Lin (Stanford Medicine), Justin Ko (Stanford Medicine), Alan Karthikesalingam and Christopher Semturs. We also thank Yetunde Ibitoye, Sami Lachgar, Lisa Lehmann, Javier Perez, Margaret Ann Smith (Stanford Medicine), Rachelle Sico, Amit Talreja, Annisah Um’rani and Wayne Westerlind for their essential contributions to this work. Finally, we are grateful to Heather Cole-Lewis, Naama Hammel, Ivor Horn, Michael Howell, Yun Liu, and Eric Teasley for their insightful comments on the study design and manuscript.
Posted by Mark Matthews, Senior Software Engineer, and Dmitry Lagun, Research Scientist, Google Research A person's prior experience and understanding of the world generally enables them to easily infer what an object looks like in whole, even if only looking at a few 2D pictures of it. Yet the capacity for a computer to reconstruct the shape of an object in 3D given only a few images has remained a difficult algorithmic problem for years. This fundamental computer vision task has applications ranging from the creation of e-commerce 3D models to autonomous vehicle navigation. A key part of the problem is how to determine the exact positions from which images were taken, known as pose inference. If camera poses are known, a range of successful techniques — such as neural radiance fields (NeRF) or 3D Gaussian Splatting — can reconstruct an object in 3D. But if these poses are not available, then we face a difficult “chicken and egg” problem where we could determine the poses if we knew the 3D object, but we can’t reconstruct the 3D object until we know the camera poses. The problem is made harder by pseudo-symmetries — i.e., many objects look similar when viewed from different angles. For example, square objects like a chair tend to look similar every 90° rotation. Pseudo-symmetries of an object can be revealed by rendering it on a turntable from various angles and plotting its photometric self-similarity map. Self-Similarity map of a toy truck model. Left: The model is rendered on a turntable from various azimuthal angles, θ. Right: The average L2 RGB similarity of a rendering from θ with that of θ*. The pseudo-similarities are indicated by the dashed red lines. ill-posed, with naïve approaches often converging to local minima. In practice, such an approach might mistake the back view as the front view of an object, because they share a similar silhouette. 
Previous techniques (such as BARF or SAMURAI) side-step this problem by relying on an initial pose estimate that starts close to the global minimum. But how can we approach this if such estimates aren’t available? Methods such as GNeRF and VMRF leverage generative adversarial networks (GANs) to overcome the problem. These techniques can artificially “amplify” a limited number of training views, aiding reconstruction. GAN techniques, however, often have complex, sometimes unstable, training processes, making robust and reliable convergence difficult to achieve in practice. A range of other successful methods, such as SparsePose or RUST, can infer poses from a limited number of views, but require pre-training on a large dataset of posed images, which isn’t always available, and can suffer from “domain-gap” issues when inferring poses for different types of images. In “MELON: NeRF with Unposed Images in SO(3)”, spotlighted at 3DV 2024, we present a technique that can determine object-centric camera poses entirely from scratch while reconstructing the object in 3D. MELON (Modulo Equivalent Latent Optimization of NeRF) is one of the first techniques that can do this without initial camera pose estimates, complex training schemes or pre-training on labeled data. MELON is a relatively simple technique that can easily be integrated into existing NeRF methods. We demonstrate that MELON can reconstruct a NeRF from unposed images with state-of-the-art accuracy while requiring as few as 4–6 images of an object. MELON convolutional neural network (CNN) encoder that regresses camera poses from training images. We pass a downscaled training image to a four-layer CNN that infers the camera pose. This CNN is initialized from noise and requires no pre-training. Its capacity is so small that it maps similar-looking images to similar poses, providing an implicit regularization that greatly aids convergence. 
The second technique is a modulo loss that simultaneously considers pseudo symmetries of an object. We render the object from a fixed set of viewpoints for each training image, backpropagating the loss only through the view that best fits the training image. This effectively considers the plausibility of multiple views for each image. In practice, we find N=2 views (viewing an object from the other side) is all that’s required in most cases, but sometimes get better results with N=4 for square objects. These two techniques are integrated into standard NeRF training, except that instead of fixed camera poses, poses are inferred by the CNN and duplicated by the modulo loss. Photometric gradients back-propagate through the best-fitting cameras into the CNN. We observe that cameras generally converge quickly to globally optimal poses (see animation below). After training of the neural field, MELON can synthesize novel views using standard NeRF rendering methods. We simplify the problem by using the NeRF-Synthetic dataset, a popular benchmark for NeRF research and common in the pose-inference literature. This synthetic dataset has cameras at precisely fixed distances and a consistent “up” orientation, requiring us to infer only the polar coordinates of the camera. This is the same as an object at the center of a globe with a camera always pointing at it, moving along the surface. We then only need the latitude and longitude (2 degrees of freedom) to specify the camera pose. MELON uses a dynamically trained lightweight CNN encoder that predicts a pose for each image. Predicted poses are replicated by the modulo loss, which only penalizes the smallest L2 distance from the ground truth color. At evaluation time, the neural field can be used to generate novel views. Results peak signal-to-noise ratio (PSNR) against held out test views. 
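The modulo loss described above can be sketched in a few lines. This is a simplified illustration: it assumes the replicated viewpoints are spaced evenly in azimuth (360°/N apart), and it uses a toy stand-in renderer (`toy_render`, hypothetical) instead of a neural field. In actual training, the gradient would flow only through the best-fitting replica.

```python
import math


def replicate_azimuths(azimuth_deg, n_views):
    """Replicate a predicted azimuth at even spacings (assumed 360/N apart)."""
    return [(azimuth_deg + k * 360.0 / n_views) % 360.0 for k in range(n_views)]


def l2_loss(rendered, target):
    """Mean squared difference between rendered and target pixel values."""
    return sum((r - t) ** 2 for r, t in zip(rendered, target)) / len(target)


def modulo_loss(render, azimuth_deg, elevation_deg, target, n_views=2):
    """Return the smallest loss over all replicas and the azimuth achieving it."""
    candidates = replicate_azimuths(azimuth_deg, n_views)
    losses = [l2_loss(render(az, elevation_deg), target) for az in candidates]
    best = min(range(n_views), key=losses.__getitem__)
    return losses[best], candidates[best]


def toy_render(azimuth_deg, elevation_deg):
    """Hypothetical renderer with a pseudo-symmetry every 180 degrees.

    The first two "pixels" repeat every 180 degrees; the third breaks the
    symmetry only weakly, like the front and back of a similar silhouette.
    Elevation is ignored in this toy example.
    """
    a = math.radians(azimuth_deg)
    return [math.cos(2 * a), math.sin(2 * a), 0.1 * math.cos(a)]


# The true camera is at 200 degrees; the encoder guesses 25 degrees,
# i.e. close to the pseudo-symmetric view on the opposite side.
target_image = toy_render(200.0, 30.0)
loss, best_azimuth = modulo_loss(toy_render, 25.0, 30.0, target_image, n_views=2)
single_view_loss = l2_loss(toy_render(25.0, 30.0), target_image)
print(loss, best_azimuth, single_view_loss)
```

In this toy case the replica on the opposite side (205°) fits the target better than the raw prediction, so the loss is taken from that view rather than pulling the prediction toward a wrong local minimum.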
We see that MELON quickly converges to the approximate poses of most cameras within the first 1,000 steps of training, and achieves a competitive PSNR of 27.5 dB after 50k steps. Convergence of MELON on a toy truck model during optimization. Left: Rendering of the NeRF. Right: Polar plot of predicted (blue x), and ground truth (red dot) cameras. Reconstruction quality comparison between ground-truth (GT) and MELON on NeRF-Synthetic scenes after 100k training steps. Noisy images novel view synthesis from extremely noisy, unposed images. We add varying amounts, σ, of white Gaussian noise to the training images. For example, the object in σ=1.0 below is impossible to make out, yet MELON can determine the pose and generate novel views of the object. Novel view synthesis from noisy unposed 128×128 images. Top: Example of noise level present in training views. Bottom: Reconstructed model from noisy training views and mean angular pose error. RawNeRF have demonstrated NeRF’s excellent de-noising capabilities with known camera poses. The fact that MELON works for noisy images of unknown camera poses so robustly was unexpected. Conclusion paper and MELON site to learn more. Acknowledgements We would like to thank our paper co-authors Axel Levy, Matan Sela, and Gordon Wetzstein, as well as Florian Schroff and Hartwig Adam for continuous help in building this technology. We also thank Matthew Brown, Ricardo Martin-Brualla and Frederic Poitevin for their helpful feedback on the paper draft. We also acknowledge the use of the computational resources at the SLAC Shared Scientific Data Facility (SDF).
Posted by Mike Schaekermann, Research Scientist, Google Research, and Ivor Horn, Chief Health Equity Officer & Director, Google Core Health equity is a major societal concern worldwide with disparities having many causes. These sources include limitations in access to healthcare, differences in clinical treatment, and even fundamental differences in the diagnostic technology. In dermatology for example, skin cancer outcomes are worse for populations such as minorities, those with lower socioeconomic status, or individuals with limited healthcare access. While there is great promise in recent advances in machine learning (ML) and artificial intelligence (AI) to help improve healthcare, this transition from research to bedside must be accompanied by a careful understanding of whether and how they impact health equity. Health equity is defined by public health organizations as fairness of opportunity for everyone to be as healthy as possible. Importantly, equity may be different from equality. For example, people with greater barriers to improving their health may require more or different effort to experience this fair opportunity. Similarly, equity is not fairness as defined in the AI for healthcare literature. Whereas AI fairness often strives for equal performance of the AI technology across different patient populations, this does not center the goal of prioritizing performance with respect to pre-existing health disparities. Health equity considerations. An intervention (e.g., an ML-based tool, indicated in dark blue) promotes health equity if it helps reduce existing disparities in health outcomes (indicated in lighter blue). Health Equity Assessment of machine Learning performance (HEAL): a framework and dermatology AI model case study”, published in The Lancet eClinicalMedicine, we propose a methodology to quantitatively assess whether ML-based health technologies perform equitably. 
In other words, does the ML model perform well for those with the worst health outcomes for the condition(s) the model is meant to address? This goal anchors on the principle that health equity should prioritize and measure model performance with respect to disparate health outcomes, which may be due to a number of factors that include structural inequities (e.g., demographic, social, cultural, political, economic, environmental and geographic). The health equity framework (HEAL) Framework for Health Equity Assessment of machine Learning performance (HEAL). Our guiding principle is to avoid exacerbating health inequities, and these steps help us identify disparities and assess for inequitable model performance to move towards better outcomes for all. Case study on a dermatology model prior work. This example dermatology model was trained to classify 288 skin conditions using a development dataset of 29k cases. The input to the model consists of three photos of a skin concern along with demographic information and a brief structured medical history. The output consists of a ranked list of possible matching skin conditions. Using the HEAL framework, we evaluated this model by assessing whether it prioritized performance with respect to pre-existing health outcomes. The model was designed to predict possible dermatologic conditions (from a list of hundreds) based on photos of a skin concern and patient metadata. Evaluation of the model is done using a top-3 agreement metric, which quantifies how often the top 3 output conditions match the most likely condition as suggested by a dermatologist panel. The HEAL metric is computed via the anticorrelation of this top-3 agreement with health outcome rankings. We used a dataset of 5,420 teledermatology cases, enriched for diversity in age, sex and race/ethnicity, to retrospectively evaluate the model’s HEAL metric. 
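To illustrate the anticorrelation idea behind the HEAL metric, here is a minimal sketch that ranks subgroups by health outcome and by model performance, then measures how strongly the two rankings oppose each other using a Spearman rank correlation. The subgroup values are hypothetical, ties are assumed absent, and the published metric may involve additional steps (for example, resampling to estimate a likelihood) that are not shown here.

```python
def ranks(values, descending=False):
    """Rank values from 1..n (assumes no ties)."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=descending)
    result = [0] * len(values)
    for position, index in enumerate(order, start=1):
        result[index] = position
    return result


def spearman(rank_a, rank_b):
    """Spearman rank correlation for untied ranks."""
    n = len(rank_a)
    d_squared = sum((a - b) ** 2 for a, b in zip(rank_a, rank_b))
    return 1.0 - 6.0 * d_squared / (n * (n * n - 1))


def outcome_performance_anticorrelation(health_burden, performance):
    """Positive when subgroups with worse outcomes get better performance."""
    outcome_rank = ranks(health_burden)                     # rank 1 = lowest burden
    performance_rank = ranks(performance, descending=True)  # rank 1 = best performance
    return -spearman(outcome_rank, performance_rank)


# Hypothetical subgroups: burden (e.g., YLLs per 100,000) and top-3 agreement.
burden = [120.0, 80.0, 200.0, 150.0]
top3_agreement = [0.62, 0.58, 0.70, 0.66]

value = outcome_performance_anticorrelation(burden, top3_agreement)
print(value)
```

A value near 1 means the model performs best for the subgroups with the worst pre-existing outcomes; a value near -1 means it performs best for those already doing well.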
The dataset consisted of “store-and-forward” cases from patients of 20 years or older from primary care providers in the USA and skin cancer clinics in Australia. Based on a review of the literature, we decided to explore race/ethnicity, sex and age as potential factors of inequity, and used sampling techniques to ensure that our evaluation dataset had sufficient representation of all race/ethnicity, sex and age groups. To quantify pre-existing health outcomes for each subgroup we relied on measurements from public databases endorsed by the World Health Organization, such as Years of Life Lost (YLLs) and Disability-Adjusted Life Years (DALYs; years of life lost plus years lived with disability). HEAL metric for all dermatologic conditions across race/ethnicity subpopulations, including health outcomes (YLLs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* Higher is better; measures the likelihood the model performs equitably with respect to the axes in this table.) HEAL metric for all dermatologic conditions across sexes, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.) HEAL metrics for all cancer and non-cancer dermatologic conditions across age groups, including health outcomes (DALYs per 100,000), model performance (top-3 agreement), and rankings for health outcomes and tool performance. (* As above.) Putting things in context Pareto condition (discussed further in the paper), which restricts model changes so that outcomes for each subpopulation are either unchanged or improved compared to the status quo, and performance does not worsen for any subpopulation. The HEAL framework, in its current form, assesses the likelihood that an ML-based model prioritizes performance for subpopulations with respect to pre-existing health disparities for specific subpopulations. 
This differs from the goal of understanding whether ML will reduce disparities in outcomes across subpopulations in reality. Specifically, modeling improvements in outcomes requires a causal understanding of steps in the care journey that happen both before and after use of any given model. Future research is needed to address this gap. Conclusion Acknowledgements The research described here is joint work across many teams at Google. We are grateful to all our co-authors: Terry Spitz, Malcolm Pyles, Heather Cole-Lewis, Ellery Wulczyn, Stephen R. Pfohl, Donald Martin, Jr., Ronnachai Jaroensri, Geoff Keeling, Yuan Liu, Stephanie Farquhar, Qinghan Xue, Jenna Lester, Cían Hughes, Patricia Strachan, Fraser Tan, Peggy Bui, Craig H. Mermel, Lily H. Peng, Yossi Matias, Greg S. Corrado, Dale R. Webster, Sunny Virmani, Christopher Semturs, Yun Liu, and Po-Hsuan Cameron Chen. We also thank Lauren Winer, Sami Lachgar, Ting-An Lin, Aaron Loh, Morgan Du, Jenny Rizk, Renee Wong, Ashley Carrick, Preeti Singh, Annisah Um'rani, Jessica Schrouff, Alexander Brown, and Anna Iurchenko for their support of this project.
Posted by Yun Zhu and Lijuan Liu, Software Engineers, Google Research Large language model (LLM) advancements have led to a new paradigm that unifies various natural language processing (NLP) tasks within an instruction-following framework. This paradigm is exemplified by recent multi-task LLMs, such as T0, FLAN, and OPT-IML. First, multi-task data is gathered with each task following a task-specific template, where each labeled example is converted into an instruction (e.g., "Put the concepts together to form a sentence: ski, mountain, skier") paired with a corresponding response (e.g., "Skier skis down the mountain"). These instruction-response pairs are used to train the LLM, resulting in a conditional generation model that takes an instruction as input and generates a response. Moreover, multi-task LLMs have exhibited remarkable task-wise generalization capabilities as they can address unseen tasks by understanding and solving brand-new instructions. The demonstration of the instruction-following pre-training of multi-task LLMs, e.g., FLAN. Pre-training on tasks under this paradigm improves performance on unseen tasks. FLAN-11B, T0-11B and OPT-IML-175B). As a result, operating such sizable models poses significant challenges because they demand considerable computational power and impose substantial requirements on the memory capacities of GPUs and TPUs, making their training and inference expensive and inefficient. Extensive storage is required to maintain a unique LLM copy for each downstream task. Moreover, the most powerful multi-task LLMs (e.g., FLAN-PaLM-540B) are closed-source, making them impossible to adapt. However, in practical applications, harnessing a single multi-task LLM to manage all conceivable tasks in a zero-shot manner remains difficult, particularly when dealing with complex tasks, personalized tasks and those that cannot be succinctly defined using instructions. 
On the other hand, the size of downstream training data is usually insufficient to train a model well without incorporating rich prior knowledge. Hence, it has long been desirable to adapt LLMs with downstream supervision while bypassing storage, memory, and access issues. Certain parameter-efficient tuning strategies, including prompt tuning and adapters, substantially diminish storage requirements, but they still perform back-propagation through LLM parameters during the tuning process, thereby keeping their memory demands high. Additionally, some in-context learning techniques circumvent parameter tuning by integrating a limited number of supervised examples into the instruction. However, these techniques are constrained by the model's maximum input length, which permits only a few samples to guide task resolution. In “Cappy: Outperforming and Boosting Large Multi-Task LMs with a Small Scorer”, presented at NeurIPS 2023, we propose a novel approach that enhances the performance and efficiency of multi-task LLMs. We introduce a lightweight pre-trained scorer, Cappy, based on continual pre-training on top of RoBERTa with merely 360 million parameters. Cappy takes in an instruction and a candidate response as input, and produces a score between 0 and 1, indicating an estimated correctness of the response with respect to the instruction. Cappy functions either independently on classification tasks or serves as an auxiliary component for LLMs, boosting their performance. Moreover, Cappy efficiently enables downstream supervision without requiring any fine-tuning of the LLM, which avoids the need for back-propagation through LLM parameters and reduces memory requirements. Finally, adaptation with Cappy doesn’t require access to LLM parameters, so it is compatible with closed-source multi-task LLMs, such as those only accessible via web APIs. 
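To make the scoring interface concrete, here is a minimal sketch of selecting among candidate responses with a scorer that maps an (instruction, response) pair to a value between 0 and 1. Only that interface comes from the description above. The `toy_scorer` function is a hypothetical stand-in based on word overlap, not Cappy itself, and `select_best` is an illustrative helper name.

```python
import re
from typing import Callable, Sequence, Tuple

# A scorer maps (instruction, response) to a score in [0, 1].
Scorer = Callable[[str, str], float]


def toy_scorer(instruction: str, response: str) -> float:
    """Hypothetical stand-in for a learned scorer.

    Returns the fraction of response words that also occur in the
    instruction. A real scorer would be a pre-trained model.
    """
    instruction_words = set(re.findall(r"[a-z]+", instruction.lower()))
    response_words = re.findall(r"[a-z]+", response.lower())
    if not response_words:
        return 0.0
    hits = sum(word in instruction_words for word in response_words)
    return hits / len(response_words)


def select_best(
    instruction: str, candidates: Sequence[str], scorer: Scorer
) -> Tuple[str, float]:
    """Score every candidate and return the highest-scoring one."""
    scored = [(scorer(instruction, c), c) for c in candidates]
    best_score, best_candidate = max(scored, key=lambda pair: pair[0])
    return best_candidate, best_score


instruction = "Put the concepts together to form a sentence: ski, mountain, skier"
candidates = ["The cat sleeps", "Skier skis down the mountain"]
best, score = select_best(instruction, candidates, toy_scorer)
print(best, round(score, 2))
```

For a classification task the candidates would be the answer choices; for a generation task they would be responses sampled from a frozen multi-task LLM, which is why no access to the LLM's parameters is needed.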
Cappy takes an instruction and response pair as input and outputs a score ranging from 0 to 1, indicating an estimation of the correctness of the response with respect to the instruction. Pre-training PromptSource that were used to train T0. This collection encompasses a wide range of task types, such as question answering, sentiment analysis, and summarization. Each dataset is associated with one or more templates that convert each instance from the original datasets into an instruction paired with its ground truth response. Cappy's regression modeling requires each pre-training data instance to include an instruction-response pair along with a correctness annotation for the response, so we produce a dataset with correctness annotations that range from 0 to 1. For every instance within a generation task, we leverage an existing multi-task LLM to generate multiple responses by sampling, conditioned on the given instruction. Subsequently, we assign an annotation to the pair formed by the instruction and every response, using the similarity between the response and the ground truth response of the instance. Specifically, we employ Rouge-L, a commonly used metric for measuring overall multi-task performance that has demonstrated a strong alignment with human evaluation, to calculate this similarity as a form of weak supervision. As a result, we obtain an effective regression dataset of 160 million instances paired with correctness score annotations. The final Cappy model is the result of continual pre-training using the regression dataset on top of the RoBERTa model. The pre-training of Cappy is conducted on Google's TPU-v4, with RedCoast, a lightweight toolkit for automating distributed training. Data augmentation with a multi-task LLM to construct a weakly supervised regression dataset for Cappy’s pre-training and fine-tuning. 
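The weak-supervision step described above can be sketched as follows. This is a minimal illustration, assuming whitespace tokenization and the F1 form of Rouge-L (the harmonic mean of LCS-based precision and recall); the exact Rouge-L variant and tokenizer used for Cappy are not stated here, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def lcs_length(a: Sequence[str], b: Sequence[str]) -> int:
    """Length of the longest common subsequence, via dynamic programming."""
    previous = [0] * (len(b) + 1)
    for token_a in a:
        current = [0]
        for j, token_b in enumerate(b, start=1):
            if token_a == token_b:
                current.append(previous[j - 1] + 1)
            else:
                current.append(max(previous[j], current[j - 1]))
        previous = current
    return previous[-1]


def rouge_l_f1(candidate: str, reference: str) -> float:
    """Rouge-L F1 between a candidate and a reference, in [0, 1]."""
    cand_tokens = candidate.lower().split()
    ref_tokens = reference.lower().split()
    if not cand_tokens or not ref_tokens:
        return 0.0
    lcs = lcs_length(cand_tokens, ref_tokens)
    if lcs == 0:
        return 0.0
    precision = lcs / len(cand_tokens)
    recall = lcs / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def build_regression_examples(
    instruction: str, ground_truth: str, sampled_responses: List[str]
) -> List[Tuple[str, str, float]]:
    """Pair each sampled response with a Rouge-L score as its weak label."""
    return [
        (instruction, response, rouge_l_f1(response, ground_truth))
        for response in sampled_responses
    ]


examples = build_regression_examples(
    "Put the concepts together to form a sentence: ski, mountain, skier",
    "Skier skis down the mountain",
    ["Skier skis down the mountain", "A skier skis", "The dog barks"],
)
for _, response, label in examples:
    print(f"{label:.2f}  {response}")
```

Each resulting triple (instruction, response, score) is one regression instance: responses close to the ground truth receive labels near 1, and unrelated responses receive labels near 0.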
Applying Cappy Adapting multi-task LLMs with Cappy Downstream adaptation comparison between Cappy and approaches that rely on an LLM’s parameters, such as fine-tuning and prompt tuning. Cappy’s application enhances multi-task LLMs. Results PromptSource. We demonstrate that Cappy, with 360M parameters, outperforms OPT-175B and OPT-IML-30B, and matches the accuracy of the best existing multi-task LLMs (T0-11B and OPT-IML-175B). These findings highlight Cappy’s capabilities and parameter efficiency, which can be credited to its scoring-based pre-training strategy that integrates contrastive information by differentiating between high-quality and low-quality responses. In contrast, previous multi-task LLMs depend exclusively on teacher-forcing training that utilizes only the ground truth responses. The overall accuracy averaged over eleven test tasks from PromptSource. “RM” refers to a pre-trained RLHF reward model. Cappy matches the best of the existing multi-task LLMs. BIG-Bench, a set of manually curated tasks that are considered beyond the capability of many LLMs. We focus on all 45 generation tasks in BIG-Bench, specifically those that do not offer pre-established answer choices. We evaluate the performance using the Rouge-L score (representing the overall similarity between model generations and corresponding ground truths) on every test set, reporting the average score across the 45 tests. In this experiment, all variants of FLAN-T5 serve as the backbone LLMs, and the foundational FLAN-T5 models are frozen. These results, shown below, suggest that Cappy enhances the performance of FLAN-T5 models by a large margin, consistently outperforming the most effective baseline achieved through sample selection using self-scoring of the LLM itself. The averaged Rouge-L score over 45 complex tasks within BIG-Bench. The x-axis refers to FLAN-T5 models of different sizes. Every dashed line represents an approach working on FLAN-T5s. 
Self-scoring refers to using the cross-entropy of the LLM to select responses. Cappy enhances the performance of FLAN-T5 models by a large margin. Conclusion Acknowledgments Thanks to Bowen Tan, Jindong Chen, Lei Meng, Abhanshu Sharma and Ewa Dominowska for their valuable feedback. We would also like to thank Eric Xing and Zhiting Hu for their suggestions.
Posted by Bahare Fatemi and Bryan Perozzi, Research Scientists, Google Research Imagine all the things around you — your friends, tools in your kitchen, or even the parts of your bike. They are all connected in different ways. In computer science, the term graph is used to describe connections between objects. Graphs consist of nodes (the objects themselves) and edges (connections between two nodes, indicating a relationship between them). Graphs are everywhere now. The internet itself is a giant graph of websites linked together. Even the knowledge search engines use is organized in a graph-like way. Furthermore, consider the remarkable advancements in artificial intelligence — such as chatbots that can write stories in seconds, and even software that can interpret medical reports. This exciting progress is largely thanks to large language models (LLMs). New LLM technology is constantly being developed for different uses. Since graphs are everywhere and LLM technology is on the rise, in “Talk like a Graph: Encoding Graphs for Large Language Models”, presented at ICLR 2024, we present a way to teach powerful LLMs how to better reason with graph information. Graphs are a useful way to organize information, but LLMs are mostly trained on regular text. The objective is to test different techniques to see what works best and gain practical insights. Translating graphs into text that LLMs can understand is a remarkably complex task. The difficulty stems from the inherent complexity of graph structures with multiple nodes and the intricate web of edges that connect them. Our work studies how to take a graph and translate it into a format that an LLM can understand. We also design a benchmark called GraphQA to study different approaches on different graph reasoning problems and show how to phrase a graph-related problem in a way that enables the LLM to solve the graph problem. 
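As a concrete illustration of what translating a graph into text can look like, here is a minimal sketch of two encodings for the same small graph: an edge list in parenthesis notation, and a per-node neighbor listing (in the spirit of what this post later calls the "incident" encoding). The sentence templates, function names, and prompt layout are assumptions made for illustration; they are not the exact wording used in GraphQA.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def encode_edge_list(num_nodes: int, edges: Sequence[Edge]) -> str:
    """Describe an undirected graph as a node list plus (u, v) pairs."""
    nodes = ", ".join(str(i) for i in range(num_nodes))
    edge_text = " ".join(f"({u}, {v})" for u, v in edges)
    return (
        f"G describes a graph among nodes {nodes}. "
        f"The edges in G are: {edge_text}."
    )


def encode_incident(num_nodes: int, edges: Sequence[Edge]) -> str:
    """Describe an undirected graph by listing each node's neighbors."""
    neighbors: List[List[int]] = [[] for _ in range(num_nodes)]
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    nodes = ", ".join(str(i) for i in range(num_nodes))
    lines = [f"G describes a graph among nodes {nodes}."]
    for node, adjacent in enumerate(neighbors):
        if adjacent:
            listed = ", ".join(str(n) for n in sorted(adjacent))
            lines.append(f"Node {node} is connected to nodes {listed}.")
    return "\n".join(lines)


def build_prompt(graph_text: str, question: str) -> str:
    """Combine the encoded graph and a question into a zero-shot prompt."""
    return f"{graph_text}\nQ: {question}\nA:"


edges = [(0, 1), (1, 2)]
print(build_prompt(encode_edge_list(3, edges), "Is there a cycle in this graph?"))
print(build_prompt(encode_incident(3, edges), "Is there a cycle in this graph?"))
```

Both prompts describe the same graph, yet they present the structure to the model very differently, which is the kind of variation the benchmark is designed to measure.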
We show that LLM performance on graph reasoning tasks varies on three fundamental levels: 1) the graph encoding method, 2) the nature of the graph task itself, and 3) interestingly, the very structure of the graph considered. These findings give us clues on how to best represent graphs for LLMs. Picking the right method can make the LLM up to 60% better at graph tasks! Pictured, the process of encoding a graph as text using two different approaches and feeding the text and a question about the graph to the LLM. Graphs as text GraphQA. Think of GraphQA as an exam designed to evaluate powerful LLMs on graph-specific problems. We want to see how well LLMs can understand and solve problems that involve graphs in different setups. To create a comprehensive and realistic exam for LLMs, we don’t just use one type of graph; we use a mix of graphs, ensuring breadth in the number of connections. This is mainly because different graph types make solving such problems easier or harder. This way, GraphQA can help expose biases in how an LLM thinks about the graphs, and the whole exam gets closer to a realistic setup that LLMs might encounter in the real world. Overview of our framework for reasoning with graphs using LLMs. Erdős–Rényi, scale-free networks, Barabási–Albert model, and stochastic block model, as well as simpler graph structures like paths, complete graphs, and star graphs, providing a diverse set of data for training. When working with graphs, we also need to find ways to ask graph-related questions that LLMs can understand. Prompting heuristics are different strategies for doing this. Let's break down the common ones: Zero-shot: Simply describe the task ("Is there a cycle in this graph?") and tell the LLM to go for it. No examples provided. Few-shot: This is like giving the LLM a mini practice test before the real deal. We provide a few example graph questions and their correct answers. 
Chain-of-Thought: Here, we show the LLM how to break down a problem step-by-step with examples. The goal is to teach it to generate its own "thought process" when faced with new graphs. Zero-CoT: Similar to CoT, but instead of training examples, we give the LLM a simple prompt, like "Let's think step-by-step," to trigger its own problem-solving breakdown. BAG (build a graph): This is specifically for graph tasks. We add the phrase "Let's build a graph..." to the description, helping the LLM focus on the graph structure. We explored different ways to translate graphs into text that LLMs can work with. Our key questions were: Node encoding: How do we represent individual nodes? Options tested include simple integers, common names (people, characters), and letters. Edge encoding: How do we describe the relationships between nodes? Methods involved parenthesis notation, phrases like "are friends", and symbolic representations like arrows. Various node and edge encodings were combined systematically. This led to functions like the ones in the following figure: Examples of graph encoding functions used to encode graphs via text. Analysis and results How LLMs handle graph tasks LLMs struggle: On most of these basic tasks, LLMs did not do much better than a random guess. Encoding matters significantly: How we represent the graph as text has a great effect on LLM performance. The "incident" encoding excelled for most of the tasks in general. Our results are summarized in the following chart. Comparison of various graph encoder functions based on their accuracy on different graph tasks. The main conclusion from this figure is that the graph encoding functions matter significantly. Bigger is (usually) better PaLM 2. Here is a summary of our findings: In general, bigger models did better on graph reasoning tasks. It seems like the extra parameters gave them space to learn more complex patterns. 
Oddly, size didn't matter as much for the “edge existence” task (finding out if two nodes in a graph are connected). Even the biggest LLM couldn't consistently beat a simple baseline solution on the cycle check problem (finding out if a graph contains a cycle or not). This shows LLMs still have room to improve with certain graph tasks. Effect of model capacity on graph reasoning task for PaLM 2-XXS, XS, S, and L. Do different graph shapes confuse LLMs? Samples of graphs generated with different graph generators from GraphQA. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively. Comparing different graph generators on different graph tasks. The main observation here is that graph structure has a significant impact on the LLM’s performance. ER, BA, SBM, and SFN refer to Erdős–Rényi, Barabási–Albert, Stochastic Block Model, and Scale-Free Network respectively. Conclusion How to translate the graph to text: How we represent the graph as text significantly influences LLM performance. The incident encoding excelled for most of the tasks in general. Task type: Certain types of graph questions tend to be harder for LLMs, even with a good translation from graph to text. Graph structure: Surprisingly, the "shape" of the graph on which we do inference (dense with connections, sparse, etc.) influences how well an LLM does. This study revealed key insights about how to prepare graphs for LLMs. The right encoding techniques can significantly boost an LLM's accuracy on graph problems (ranging from around 5% to over 60% improvement). Our new benchmark, GraphQA, will help drive further research in this area. Acknowledgements We would like to express our gratitude to our co-author, Jonathan Halcrow, for his valuable contributions to this work. 
We express our sincere gratitude to Anton Tsitsulin, Dustin Zelle, Silvio Lattanzi, Vahab Mirrokni, and the entire graph mining team at Google Research, for their insightful comments, thorough proofreading, and constructive feedback which greatly enhanced the quality of our work. We would also like to extend special thanks to Tom Small for creating the animation used in this post.
Posted by Zilong Wang, Student Researcher, and Chen-Yu Lee, Research Scientist, Cloud AI Team People use tables every day to organize and interpret complex information in a structured, easily accessible format. Due to the ubiquity of such tables, reasoning over tabular data has long been a central topic in natural language processing (NLP). Researchers in this field have aimed to leverage language models to help users answer questions, verify statements, and analyze data based on tables. However, language models are trained over large amounts of plain text, so the inherently structured nature of tabular data can be difficult for language models to fully comprehend and utilize. Recently, large language models (LLMs) have achieved outstanding performance across diverse natural language understanding (NLU) tasks by generating reliable reasoning chains, as shown in works like Chain-of-Thought and Least-to-Most. However, the most suitable way for LLMs to reason over tabular data remains an open question. In “Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding”, we propose a framework to tackle table understanding tasks, where we train LLMs to outline their reasoning step by step, updating a given table iteratively to reflect each part of a thought process, akin to how people solve the table-based problems. This enables the LLM to transform the table into simpler and more manageable segments so that it can understand and analyze each part of the table in depth. This approach has yielded significant improvements and achieved new state-of-the-art results on the WikiTQ, TabFact, and FeTaQA benchmarks. The figure below shows the high-level overview of the proposed Chain-of-Table and other methods. 
Given a complex table where a cyclist’s nationality and name are in the same cell, (a) generic, multi-step reasoning is unable to provide the correct answer; (b) program-aided reasoning generates and executes programs (e.g., SQL queries) to deliver the answer, but falls short in accurately addressing the question. In contrast, (c) Chain-of-Table iteratively samples a chain of operations that effectively transform the complex table into a version specifically tailored to the question. Chain-of-Table in-context learning to iteratively generate operations and to update the table to represent its reasoning chain over tabular data. This enables LLMs to dynamically plan the next operation based on the results of previous ones. This continuous evolution of the table forms a chain, which provides a more structured and clear representation of the reasoning process for a given problem and enables more accurate and reliable predictions from the LLM. For example, when asked, “Which actor has the most NAACP image awards?” the Chain-of-Table framework prompts an LLM to generate tabular operations mirroring tabular reasoning processes. It first identifies the relevant columns. Then, it aggregates rows based on shared content. Finally, it reorders the aggregated results to yield a final table that clearly answers the posed question. These operations transform the table to align with the question presented. To balance performance with computational expense on large tables, we construct the operation chain according to a subset of tabular rows. Meanwhile, the step-by-step operations reveal the underlying reasoning process through the display of intermediate results from the tabular operations, fostering enhanced interpretability and understanding. Illustration of the tabular reasoning process in Chain-of-Table. This iterative process involves dynamically planning an operation chain and accurately storing intermediate results in the transformed tables. 
These intermediate tables serve as a tabular thought process that can guide the LLM to arrive at the correct answer more reliably. The question Q: “Which country had the most cyclists finish in the top 3?” The operation history chain: f_add_col(Country) and f_select_row(1, 2, 3). The latest intermediate table T: the transformed intermediate table. By providing the triplet (T, Q, chain) in the prompt, the LLM can observe the previous tabular reasoning process and select the next operation from the operation pool to complete the reasoning chain step by step. Illustration of how Chain-of-Table selects the next operation from the operation pool and generates the arguments for the operation. (a) Chain-of-Table samples the next operation from the operation pool. (b) It takes the selected operation as input and generates its arguments. After the next operation f is determined, in the second stage, we need to generate the arguments. As above, Chain-of-Table considers three components in the prompt as shown in the figure: (1) the question, (2) the selected operation and its required arguments, and (3) the latest intermediate table. For instance, when the operation f_group_by is selected, it requires a header name as its argument. The LLM selects a suitable header within the table. Equipped with the selected operation and the generated arguments, Chain-of-Table executes the operation and constructs a new intermediate table for the following reasoning. Chain-of-Table iterates the previous two stages to plan the next operation and generate the required arguments. During this process, we create an operation chain acting as a proxy for the tabular reasoning steps. These operations generate intermediate tables presenting the results of each step to the LLM. Consequently, the output table contains comprehensive information about the intermediate phases of tabular reasoning. 
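The plan-then-execute loop described above can be sketched as follows. This is a minimal illustration under stated assumptions: the table contents are made up, `f_sort_by` is an assumed fourth operation added so the example can finish, and `scripted_planner` is a hypothetical stand-in that returns a fixed sequence of operations and arguments where the real framework would prompt an LLM with (T, Q, chain).

```python
import re
from collections import Counter
from typing import Any, Dict, List

Row = Dict[str, Any]


def f_add_col(table: List[Row], name: str, values: List[Any]) -> List[Row]:
    """Add a new column whose values were generated for each row."""
    return [{**row, name: value} for row, value in zip(table, values)]


def f_select_row(table: List[Row], indices: List[int]) -> List[Row]:
    """Keep only the rows at the given 1-based positions."""
    return [table[i - 1] for i in indices]


def f_group_by(table: List[Row], header: str) -> List[Row]:
    """Count how many rows share each value of the given column."""
    counts = Counter(row[header] for row in table)
    return [{header: value, "Count": n} for value, n in counts.items()]


def f_sort_by(table: List[Row], header: str) -> List[Row]:
    """Order rows by the given column, largest first (assumed operation)."""
    return sorted(table, key=lambda row: row[header], reverse=True)


OPERATION_POOL = {
    "f_add_col": f_add_col,
    "f_select_row": f_select_row,
    "f_group_by": f_group_by,
    "f_sort_by": f_sort_by,
}


def scripted_planner(table, question, chain):
    """Hypothetical stand-in for the LLM: pick the next operation and args."""
    step = len(chain)
    if step == 0:
        countries = [re.search(r"\((\w+)\)", r["Cyclist"]).group(1) for r in table]
        return "f_add_col", ("Country", countries)
    if step == 1:
        return "f_select_row", ([1, 2, 3],)
    if step == 2:
        return "f_group_by", ("Country",)
    if step == 3:
        return "f_sort_by", ("Count",)
    return None  # signals the end of the chain


def chain_of_table(table, question, planner, max_steps=10):
    """Iteratively plan an operation, execute it, and update the table."""
    chain = []
    for _ in range(max_steps):
        step = planner(table, question, chain)
        if step is None:
            break
        name, args = step
        table = OPERATION_POOL[name](table, *args)
        chain.append((name, args))
    return table, chain


# Made-up example data with nationality and name in the same cell.
table = [
    {"Rank": 1, "Cyclist": "Rider A (ITA)"},
    {"Rank": 2, "Cyclist": "Rider B (ESP)"},
    {"Rank": 3, "Cyclist": "Rider C (ITA)"},
    {"Rank": 4, "Cyclist": "Rider D (FRA)"},
]
question = "Which country had the most cyclists finish in the top 3?"
final_table, chain = chain_of_table(table, question, scripted_planner)
print([name for name, _ in chain])
print(final_table)
```

The final table has one row per country with its count, sorted so the answer sits in the first row, which is the simplified table that would be handed to the LLM together with the question.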
In our final stage, we employ this output table in formulating the final query and prompt the LLM along with the question for the final answer. Experimental setup We use PaLM 2-S and GPT 3.5 as the backbone LLMs and conduct the experiments on three public table understanding benchmarks: WikiTQ, TabFact, and FeTaQA. WikiTQ and FeTaQA are datasets for table-based question answering. TabFact is a table-based fact verification benchmark. In this blog post, we will focus on the results on WikiTQ and TabFact. We compare Chain-of-Table with the generic reasoning methods (e.g., End-to-End QA, Few-Shot QA, and Chain-of-Thought) and the program-aided methods (e.g., Text-to-SQL, Binder, and Dater). More accurate answers PaLM 2 and GPT 3.5. This is attributed to the dynamically sampled operations and the informative intermediate tables. Understanding results on WikiTQ and TabFact with PaLM 2 and GPT 3.5 compared with various models. Better robustness on harder questions PaLM 2 on WikiTQ. Performance of Chain-of-Thought, Dater, and the proposed Chain-of-Table on WikiTQ for questions that require an operation chain of varying lengths. Our proposed atomic operations significantly improve performance over generic and program-aided reasoning counterparts. Chain-of-Thought, and up to 7.9% compared with Dater. Moreover, the performance of Chain-of-Table declines gracefully with increasing number of operations compared to other baseline methods, exhibiting only a minimal decrease when the number of operations increases from four to five. Better robustness with larger tables WikiTQ into three groups based on token number: small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens). We then compare Chain-of-Table with Dater and Binder, the two latest and strongest baselines. Performance of Binder, Dater, and the proposed Chain-of-Table on small (<2000 tokens), medium (2000 to 4000 tokens), and large (>4000 tokens) tables from WikiTQ. 
We observe that the performance decreases with larger input tables while Chain-of-Table diminishes gracefully, achieving significant improvements over competing methods. (As above, underlined text denotes the second-best performance; bold denotes the best performance.) As anticipated, the performance decreases with larger input tables, as models are required to reason through longer contexts. Nevertheless, the performance of the proposed Chain-of-Table diminishes gracefully, achieving a significant 10+% improvement over the second best competing method when dealing with large tables. This demonstrates the efficacy of the reasoning chain in handling long tabular inputs. Conclusion Acknowledgements This research was conducted by Zilong Wang, Hao Zhang, Chun-Liang Li, Julian Martin Eisenschlos, Vincent Perot, Zifeng Wang, Lesly Miculicich, Yasuhisa Fujii, Jingbo Shang, Chen-Yu Lee, Tomas Pfister. Thanks to Chih-Kuan Yeh and Sergey Ioffe for their valuable feedback.
Posted by Dave Steiner, Clinical Research Scientist, Google Health, and Rory Pilgrim, Product Manager, Google Research There’s a worldwide shortage of access to medical imaging expert interpretation across specialties including radiology, dermatology and pathology. Machine learning (ML) technology can help ease this burden by powering tools that enable doctors to interpret these images more accurately and efficiently. However, the development and implementation of such ML tools are often limited by the availability of high-quality data, ML expertise, and computational resources. One way to catalyze the use of ML for medical imaging is via domain-specific models that utilize deep learning (DL) to capture the information in medical images as compressed numerical vectors (called embeddings). These embeddings represent a type of pre-learned understanding of the important features in an image. Identifying patterns in the embeddings reduces the amount of data, expertise, and compute needed to train performant models as compared to working with high-dimensional data, such as images, directly. Indeed, these embeddings can be used to perform a variety of downstream tasks within the specialized domain (see animated graphic below). This framework of leveraging pre-learned understanding to solve related tasks is similar to that of a seasoned guitar player quickly learning a new song by ear. Because the guitar player has already built up a foundation of skill and understanding, they can quickly pick up the patterns and groove of a new song. Path Foundation is used to convert a small dataset of (image, label) pairs into (embedding, label) pairs. These pairs can then be used to train a task-specific classifier using a linear probe, (i.e., a lightweight linear classifier) as represented in this graphic, or other types of models using the embeddings as input. Once the linear probe is trained, it can be used to make predictions on embeddings from new images. 
These predictions can be compared to ground truth information in order to evaluate the linear probe's performance. Derm Foundation and Path Foundation. This follows on the strong response we’ve already received from researchers using the CXR Foundation embedding tool for chest radiographs and represents a portion of our expanding research offerings across multiple medical-specialized modalities. These embedding tools take an image as input and produce a numerical vector (the embedding) that is specialized to the domains of dermatology and digital pathology images, respectively. By running a dataset of chest X-ray, dermatology, or pathology images through the respective embedding tool, researchers can obtain embeddings for their own images, and use these embeddings to quickly develop new models for their applications. Path Foundation In “Domain-specific optimization and diverse evaluation of self-supervised models for histopathology”, we showed that self-supervised learning (SSL) models for pathology images outperform traditional pre-training approaches and enable efficient training of classifiers for downstream tasks. This effort focused on hematoxylin and eosin (H&E) stained slides, the principal tissue stain in diagnostic pathology that enables pathologists to visualize cellular features under a microscope. The performance of linear classifiers trained using the output of the SSL models matched that of prior DL models trained on orders of magnitude more labeled data. Due to substantial differences between digital pathology images and “natural image” photos, this work involved several pathology-specific optimizations during model training. One key element is that whole-slide images (WSIs) in pathology can be 100,000 pixels across (thousands of times larger than typical smartphone photos) and are analyzed by experts at multiple magnifications (zoom levels). 
As such, the WSIs are typically broken down into smaller tiles or patches for computer vision and DL applications. The resulting images are information dense with cells or tissue structures distributed throughout the frame instead of having distinct semantic objects or foreground vs. background variations, thus creating unique challenges for robust SSL and feature extraction. Additionally, physical (e.g., cutting) and chemical (e.g., fixing and staining) processes used to prepare the samples can influence image appearance dramatically. Taking these important aspects into consideration, pathology-specific SSL optimizations included helping the model learn stain-agnostic features, generalizing the model to patches from multiple magnifications, augmenting the data to mimic scanning and image post processing, and custom data balancing to improve input heterogeneity for SSL training. These approaches were extensively evaluated using a broad set of benchmark tasks involving 17 different tissue types over 12 different tasks. Utilizing the vision transformer (ViT-S/16) architecture, Path Foundation was selected as the best performing model from the optimization and evaluation process described above (and illustrated in the figure below). This model thus provides an important balance between performance and model size to enable valuable and scalable use in generating embeddings over the many individual image patches of large pathology WSIs. SSL training with pathology-specific optimizations for Path Foundation. AUROC) compared to traditional pre-training on natural images (ImageNet-21k). This includes evaluation for tasks such as metastatic breast cancer detection in lymph nodes, prostate cancer grading, and breast cancer grading, among others. Path Foundation embeddings significantly outperform traditional ImageNet embeddings as evaluated by linear probing across multiple evaluation tasks in histopathology. 
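Linear probing, the evaluation approach referred to above, can be sketched as follows: a lightweight linear classifier is trained on fixed embeddings while the embedding model itself is left untouched. This is a minimal sketch using only the Python standard library. The embeddings here are synthetic 4-dimensional vectors generated by a hypothetical `make_synthetic_embeddings` helper; in practice they would be the vectors returned by the embedding tool, and a library implementation of logistic regression would normally be used instead of this hand-written one.

```python
import math
import random
from typing import List, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_linear_probe(
    embeddings: List[List[float]],
    labels: List[int],
    lr: float = 0.5,
    epochs: int = 200,
) -> Tuple[List[float], float]:
    """Fit a logistic-regression classifier on frozen embeddings."""
    dim, n = len(embeddings[0]), len(embeddings)
    weights, bias = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(embeddings, labels):
            score = sum(w * xi for w, xi in zip(weights, x)) + bias
            error = sigmoid(score) - y
            for i, xi in enumerate(x):
                grad_w[i] += error * xi
            grad_b += error
        # Full-batch gradient descent step on the mean log loss.
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


def predict(weights: List[float], bias: float, x: List[float]) -> int:
    """Predict class 1 when the linear score is non-negative."""
    return 1 if sum(w * xi for w, xi in zip(weights, x)) + bias >= 0 else 0


def accuracy(weights, bias, embeddings, labels) -> float:
    correct = sum(predict(weights, bias, x) == y for x, y in zip(embeddings, labels))
    return correct / len(labels)


def make_synthetic_embeddings(n_per_class: int, dim: int = 4, seed: int = 0):
    """Hypothetical stand-in for real embeddings: two Gaussian clusters."""
    rng = random.Random(seed)
    data, labels = [], []
    for label, center in ((0, -1.0), (1, 1.0)):
        for _ in range(n_per_class):
            data.append([rng.gauss(center, 0.3) for _ in range(dim)])
            labels.append(label)
    return data, labels


train_x, train_y = make_synthetic_embeddings(50, seed=0)
test_x, test_y = make_synthetic_embeddings(50, seed=1)
weights, bias = train_linear_probe(train_x, train_y)
print("held-out accuracy:", accuracy(weights, bias, test_x, test_y))
```

Because only the small linear layer is trained, the probe needs far less labeled data and compute than training a model on the images directly, which is what makes it a convenient way to compare embedding models.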
Derm Foundation Derm Foundation is an embedding tool derived from our research in applying DL to interpret images of dermatology conditions and includes our recent work that adds improvements to generalize better to new datasets. Due to its dermatology-specific pre-training it has a latent understanding of features present in images of skin conditions and can be used to quickly develop models to classify skin conditions. The model underlying the API is a BiT ResNet-101x3 trained in two stages. The first pre-training stage uses contrastive learning, similar to ConVIRT, to train on a large number of image-text pairs from the internet. In the second stage, the image component of this pre-trained model is then fine-tuned for condition classification using clinical datasets, such as those from teledermatology services. Unlike histopathology images, dermatology images more closely resemble the real-world images used to train many of today's computer vision models. However, for specialized dermatology tasks, creating a high-quality model may still require a large dataset. With Derm Foundation, researchers can use their own smaller dataset to retrieve domain-specific embeddings, and use those to build smaller models (e.g., linear classifiers or other small non-linear models) that enable them to validate their research or product ideas. To evaluate this approach, we trained models on a downstream task using teledermatology data. Model training involved varying dataset sizes (12.5%, 25%, 50%, 100%) to compare embedding-based linear classifiers against fine-tuning. 
The modeling variants considered were: (1) a linear classifier on frozen embeddings from BiT-M (a standard pre-trained image model); (2) a fine-tuned version of BiT-M with an extra dense layer for the downstream task; (3) a linear classifier on frozen embeddings from the Derm Foundation API; and (4) a fine-tuned version of the model underlying the Derm Foundation API with an extra layer for the downstream task. We found that models built on top of the Derm Foundation embeddings for dermatology-related tasks achieved significantly higher quality than those built solely on embeddings or fine-tuned from BiT-M. This advantage was most pronounced for smaller training dataset sizes. These results demonstrate that the Derm Foundation tool can serve as a useful starting point to accelerate skin-related modeling tasks. We aim to enable other researchers to build on the underlying features and representations of dermatology that the model has learned.
Access Path and Derm Foundation
Access forms: Derm Foundation Access Form and Path Foundation Access Form. After gaining access to each tool, you can use the API to retrieve embeddings from dermatology images or digital pathology images stored in Google Cloud. Approved users who are just curious to see the model and embeddings in action can use the provided example Colab notebooks to train models using public data for classifying six common skin conditions or identifying tumors in histopathology patches. We look forward to seeing the range of use-cases these tools can unlock.
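The embedding-based variant described above (a linear classifier trained on frozen embeddings, evaluated at several training-set fractions) can be sketched in a few lines. This is only a toy under stated assumptions: the "embeddings" are random vectors generated by the hypothetical helper `fake_embedding`, not outputs of the Derm Foundation API, and the classifier is a plain logistic regression written from scratch.

```python
import math
import random


def train_linear_probe(features, labels, epochs=200, lr=0.5):
    """Fit logistic regression (weights and bias) with batch gradient descent."""
    dim = len(features[0])
    weights, bias = [0.0] * dim, 0.0
    n = len(features)
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            z = sum(w * xi for w, xi in zip(weights, x)) + bias
            p = 1.0 / (1.0 + math.exp(-z))
            for i, xi in enumerate(x):
                grad_w[i] += (p - y) * xi
            grad_b += p - y
        weights = [w - lr * g / n for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b / n
    return weights, bias


def accuracy(weights, bias, features, labels):
    """Fraction of examples whose predicted class matches the label."""
    correct = 0
    for x, y in zip(features, labels):
        z = sum(w * xi for w, xi in zip(weights, x)) + bias
        correct += int((z > 0) == (y == 1))
    return correct / len(labels)


rng = random.Random(0)


def fake_embedding(label, dim=4):
    # Hypothetical stand-in for a frozen embedding: class-dependent Gaussian.
    center = 1.0 if label == 1 else -1.0
    return [rng.gauss(center, 1.0) for _ in range(dim)]


data = [(fake_embedding(i % 2), i % 2) for i in range(400)]
train, test = data[:320], data[320:]
results = {}
for fraction in (0.125, 0.25, 0.5, 1.0):
    subset = train[: int(len(train) * fraction)]
    w, b = train_linear_probe([x for x, _ in subset], [y for _, y in subset])
    results[fraction] = accuracy(w, b, [x for x, _ in test], [y for _, y in test])
    print(f"{fraction:>6.1%} of training data -> test accuracy {results[fraction]:.2f}")
```

Because the embedding model stays frozen, only the small linear layer is trained, which is what makes this approach practical with limited labeled data.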
Acknowledgements We would like to thank the many collaborators who helped make this work possible including Yun Liu, Can Kirmizi, Fereshteh Mahvar, Bram Sterling, Arman Tajback, Kenneth Philbrik, Arnav Agharwal, Aurora Cheung, Andrew Sellergren, Boris Babenko, Basil Mustafa, Jan Freyberg, Terry Spitz, Yuan Liu, Pinal Bavishi, Ayush Jain, Amit Talreja, Rajeev Rikhye, Abbi Ward, Jeremy Lai, Faruk Ahmed, Supriya Vijay, Tiam Jaroensri, Jessica Loo, Saurabh Vyawahare, Saloni Agarwal, Ellery Wulczyn, Jonathan Krause, Fayaz Jamil, Tom Small, Annisah Um'rani, Lauren Winer, Sami Lachgar, Yossi Matias, Greg Corrado, and Dale Webster.
Posted by Amirkeivan Mohtashami, Research Intern, and Florian Hartmann, Software Engineer, Google Research Large language models (LLMs) have significantly improved the state of the art for solving tasks specified using natural language, often reaching performance close to that of people. As these models increasingly enable assistive agents, it could be beneficial for them to learn effectively from each other, much like people do in social settings, which would allow LLM-based agents to improve each other’s performance. To discuss the learning processes of humans, Bandura and Walters described the concept of social learning in 1977, outlining different models of observational learning used by people. One common method of learning from others is through a verbal instruction (e.g., from a teacher) that describes how to engage in a particular behavior. Alternatively, learning can happen through a live model by mimicking a live example of the behavior. Given the success of LLMs mimicking human communication, in our paper “Social Learning: Towards Collaborative Learning with Large Language Models”, we investigate whether LLMs are able to learn from each other using social learning. To this end, we outline a framework for social learning in which LLMs share knowledge with each other in a privacy-aware manner using natural language. We evaluate the effectiveness of our framework on various datasets, and propose quantitative methods that measure privacy in this setting. In contrast to previous approaches to collaborative learning, such as common federated learning approaches that often rely on gradients, in our framework, agents teach each other purely using natural language. Social learning for LLMs spam detection in short text messages (SMS), solving grade school math problems, and answering questions based on a given text. 
A visualization of the social learning process: A teacher model provides instructions or few-shot examples to a student model without sharing its private data. few-shot learning. With this in mind, we provide human-labeled examples of a task that enables the teacher model to teach it to a student. One of the main use cases of social learning arises when these examples cannot be directly shared with the student due, for example, to privacy concerns. To illustrate this, let’s look at a hypothetical example for a spam detection task. A teacher model is located on device where some users volunteer to mark incoming messages they receive as either “spam” or “not spam”. This is useful data that could help train a student model to differentiate between spam and not spam, but sharing personal messages with other users is a breach of privacy and should be avoided. To prevent this, a social learning process can transfer the knowledge from the teacher model to the student so it learns what spam messages look like without needing to share the user’s personal text messages. We investigate the effectiveness of this social learning approach by analogy with the established human social learning theory that we discussed above. In these experiments, we use PaLM 2-S models for both the teacher and the student. A systems view of social learning: At training time, multiple teachers teach the student. At inference time, the student is using what it learned from the teachers. Synthetic examples The 8 generated examples perform as well as the original data for several tasks (see our paper). Generating 16 instead of just 8 examples further reduces the performance gap relative to the original examples. paper, we additionally look into aggregation methods for selecting good subsets of examples to use. 
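To make the synthetic-example flow above concrete, here is a minimal sketch of a teacher handing generated examples to a student, which then uses them as few-shot context. Everything named here is hypothetical: `stub_teacher` stands in for a call to a teacher LLM, the prompt format is invented for illustration, and the verbatim-copy filter is an extra safeguard assumed for this example rather than a step described in the paper.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence


@dataclass(frozen=True)
class Example:
    text: str
    label: str


Generator = Callable[[Sequence[Example], int], List[Example]]


def teach_with_synthetic_examples(
    private: Sequence[Example], generate: Generator, n: int
) -> List[Example]:
    """Return generated examples, dropping any that copy a private text verbatim."""
    private_texts = {e.text for e in private}
    return [e for e in generate(private, n) if e.text not in private_texts]


def build_student_prompt(shots: Sequence[Example], query: str) -> str:
    """Format the shared examples as a few-shot prompt for the student."""
    body = "\n\n".join(f"Message: {e.text}\nLabel: {e.label}" for e in shots)
    return f"{body}\n\nMessage: {query}\nLabel:"


def stub_teacher(private: Sequence[Example], n: int) -> List[Example]:
    # Hypothetical stand-in for a teacher LLM; returns canned examples.
    canned = [
        Example("You won a free cruise, reply YES to claim", "spam"),
        Example("Running late, see you at 6", "not spam"),
        Example("URGENT: verify your account now", "spam"),  # copies private data
    ]
    return canned[:n]


private_data = [
    Example("URGENT: verify your account now", "spam"),
    Example("Lunch tomorrow?", "not spam"),
]
shared = teach_with_synthetic_examples(private_data, stub_teacher, n=3)
prompt = build_student_prompt(shared, "Claim your prize today")
print(prompt)
```

The student only ever sees `shared`, so the private messages stay with the teacher.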
Synthetic instruction Lambada, GSM8k, and Random Insertion, providing synthetic examples performs better than providing generated instructions, whereas in the other tasks generated instruction obtains a higher accuracy. This observation suggests that the choice of the teaching model depends on the task at hand, similar to how the most effective method for teaching people varies by task. Depending on the task, generating instructions can work better than generating new examples. Memorization of the private examples Secret Sharer, a popular method for quantifying to what extent a model memorizes its training data, and adapted it to the social learning setting. We picked this method since it had previously been used for evaluating memorization in federated learning. To apply the Secret Sharer method to social learning, we design “canary” data points such that we can concretely measure how much the training process memorized them. These data points are included in the datasets used by teachers to generate new examples. After the social learning process completes, we can then measure how much more confident the student is in the secret data points the teacher used, compared to similar ones that were not shared even with the teachers. In our analysis, discussed in detail in the paper, we use canary examples that include names and codes. Our results show that the student is only slightly more confident in the canaries the teacher used. In contrast, when the original data points are directly shared with the student, the confidence in the included canaries is much higher than in the held-out set. This supports the conclusion that the teacher does indeed use its data to teach without simply copying it over. Conclusion and next steps Acknowledgements We would like to acknowledge and thank Matt Sharifi, Sian Gooding, Lukas Zilka, and Blaise Aguera y Arcas, who are all co-authors on the paper. 
Furthermore, we would like to thank Victor Cărbune, Zachary Garrett, Tautvydas Misiunas, Sofia Neata and John Platt for their feedback, which greatly improved the paper. We’d also like to thank Tom Small for creating the animated figure.
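As a supplementary sketch of the canary-based memorization measurement described in this post, the code below compares a model's confidence on canaries the teacher used against held-out canaries, and also computes the rank-based exposure metric from the original Secret Sharer work (log2 of the number of candidates minus log2 of the canary's rank). The scores are made-up log-likelihoods and the function names are assumptions; the adaptation used in the paper may differ in its details.

```python
import math
from statistics import mean
from typing import Callable, Sequence


def confidence_gap(
    score: Callable[[str], float],
    used_canaries: Sequence[str],
    heldout_canaries: Sequence[str],
) -> float:
    """Mean score on canaries the teacher saw minus mean score on held-out ones."""
    return mean(score(c) for c in used_canaries) - mean(score(c) for c in heldout_canaries)


def exposure(score: Callable[[str], float], canary: str, candidates: Sequence[str]) -> float:
    """Rank-based exposure: log2(#candidates) - log2(rank), rank 1 = most likely."""
    ranked = sorted(candidates, key=score, reverse=True)
    rank = ranked.index(canary) + 1
    return math.log2(len(candidates)) - math.log2(rank)


# Made-up log-likelihoods a student model might assign to candidate secrets.
toy_scores = {
    "code 1234": -2.0,  # canary present in the teacher's data
    "code 5678": -2.5,
    "code 9012": -3.0,
    "code 3456": -3.5,
}
score = toy_scores.__getitem__

gap = confidence_gap(score, ["code 1234"], ["code 5678", "code 9012", "code 3456"])
canary_exposure = exposure(score, "code 1234", list(toy_scores))
print(f"confidence gap: {gap:.2f}, exposure: {canary_exposure:.2f} bits")
```

A small gap suggests the student learned the task without memorizing the specific secrets, whereas a large gap would indicate that private data leaked through the teaching process.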
Posted by Omar Benjelloun, Software Engineer, Google Research, and Peter Mattson, Software Engineer, Google Core ML and President, MLCommons Association Machine learning (ML) practitioners looking to reuse existing datasets to train an ML model often spend a lot of time understanding the data, making sense of its organization, or figuring out what subset to use as features. So much time, in fact, that progress in the field of ML is hampered by a fundamental obstacle: the wide variety of data representations. ML datasets cover a broad range of content types, from text and structured data to images, audio, and video. Even within datasets that cover the same types of content, every dataset has a unique ad hoc arrangement of files and data formats. This challenge reduces productivity throughout the entire ML development process, from finding the data to training the model. It also impedes development of badly needed tooling for working with datasets. There are general purpose metadata formats for datasets such as schema.org and DCAT. However, these formats were designed for data discovery rather than for the specific needs of ML data, such as the ability to extract and combine data from structured and unstructured sources, to include metadata that would enable responsible use of the data, or to describe ML usage characteristics such as defining training, test and validation sets. Today, we're introducing Croissant, a new metadata format for ML-ready datasets. Croissant was developed collaboratively by a community from industry and academia, as part of the MLCommons effort. The Croissant format doesn't change how the actual data is represented (e.g., image or text file formats) — it provides a standard way to describe and organize it. Croissant builds upon schema.org, the de facto standard for publishing structured data on the Web, which is already used by over 40M datasets. 
Croissant augments it with comprehensive layers for ML relevant metadata, data resources, data organization, and default ML semantics. In addition, we are announcing support from major tools and repositories: Today, three widely used collections of ML datasets — Kaggle, Hugging Face, and OpenML — will begin supporting the Croissant format for the datasets they host; the Dataset Search tool lets users search for Croissant datasets across the Web; and popular ML frameworks, including TensorFlow, PyTorch, and JAX, can load Croissant datasets easily using the TensorFlow Datasets (TFDS) package. Croissant specification of the format, a set of example datasets, an open source Python library to validate, consume and generate Croissant metadata, and an open source visual editor to load, inspect and create Croissant dataset descriptions in an intuitive way. Supporting Responsible AI (RAI) was a key goal of the Croissant effort from the start. We are also releasing the first version of the Croissant RAI vocabulary extension, which augments Croissant with key properties needed to describe important RAI use cases such as data life cycle management, data labeling, participatory data, ML safety and fairness evaluation, explainability, and compliance. Why a shared format for ML data? What can Croissant do today? The Croissant ecosystem: Users can Search for Croissant datasets, download them from major repositories, and easily load them into their favorite ML frameworks. They can create, inspect and modify Croissant metadata using the Croissant editor. Google Dataset Search, which offers a Croissant filter. HuggingFace Kaggle OpenML With a Croissant dataset, it is possible to: Ingest data easily via TensorFlow Datasets for use in popular ML frameworks like TensorFlow, PyTorch, and JAX. Inspect and modify the metadata using the Croissant editor UI (github). 
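To give a feel for how the layers mentioned above fit together (dataset-level metadata, resources, and structure), here is a heavily simplified, hypothetical metadata document written as a Python dictionary, plus a small helper that walks it. The dataset name, URL, and fields are invented, and the document is abbreviated; it should not be expected to pass validation against the Croissant specification.

```python
import json

# Simplified, illustrative metadata; not a complete or validated Croissant file.
metadata = {
    "@context": {"sc": "https://schema.org/", "cr": "http://mlcommons.org/croissant/"},
    "@type": "sc:Dataset",
    "name": "toy_reviews",  # hypothetical dataset
    "description": "Hypothetical dataset used only to illustrate the layers.",
    # Resource layer: the files that make up the dataset.
    "distribution": [
        {
            "@type": "cr:FileObject",
            "@id": "reviews.csv",
            "contentUrl": "https://example.org/reviews.csv",
            "encodingFormat": "text/csv",
        }
    ],
    # Structure layer: how records and fields map onto those files.
    "recordSet": [
        {
            "@type": "cr:RecordSet",
            "@id": "reviews",
            "field": [
                {
                    "@type": "cr:Field",
                    "@id": "reviews/text",
                    "dataType": "sc:Text",
                    "source": {"fileObject": {"@id": "reviews.csv"}, "extract": {"column": "text"}},
                },
                {
                    "@type": "cr:Field",
                    "@id": "reviews/rating",
                    "dataType": "sc:Integer",
                    "source": {"fileObject": {"@id": "reviews.csv"}, "extract": {"column": "rating"}},
                },
            ],
        }
    ],
}


def list_fields(doc):
    """Return (record set id, field id, data type, source file id) tuples."""
    rows = []
    for record_set in doc.get("recordSet", []):
        for field in record_set.get("field", []):
            source_file = field["source"]["fileObject"]["@id"]
            rows.append((record_set["@id"], field["@id"], field["dataType"], source_file))
    return rows


fields = list_fields(metadata)
known_files = {f["@id"] for f in metadata["distribution"]}
dangling = [row for row in fields if row[3] not in known_files]
print(json.dumps(fields, indent=2))
```

Because the description is plain JSON-LD layered on schema.org, generic tools can read the dataset-level properties while ML-aware tools use the resource and structure layers to load records.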
To publish a Croissant dataset, users can: Use the Croissant editor UI (github) to generate a large portion of Croissant metadata automatically by analyzing the data the user provides, and to fill important metadata fields such as RAI properties. Publish the Croissant information as part of their dataset Web page to make it discoverable and reusable. Publish their data in one of the repositories that support Croissant, such as Kaggle, HuggingFace and OpenML, and automatically generate Croissant metadata. Future direction join us in contributing to the effort. Acknowledgements Croissant was developed by the Dataset Search, Kaggle and TensorFlow Datasets teams from Google, as part of an MLCommons community working group, which also includes contributors from these organizations: Bayer, cTuning Foundation, DANS-KNAW, Dotphoton, Harvard, Hugging Face, Kings College London, LIST, Meta, NASA, North Carolina State University, Open Data Institute, Open University of Catalonia, Sage Bionetworks, and TU Eindhoven.
Posted by Kate Weber and Shannon Leon, Google Research, Quantum AI Team Today the 2024 March Meeting of the American Physical Society (APS) kicks off in Minneapolis, MN. A premier conference on topics ranging across physics and related fields, APS 2024 brings together researchers, students, and industry professionals to share their discoveries and build partnerships with the goal of realizing fundamental advances in physics-related sciences and technology. This year, Google has a strong presence at APS with a booth hosted by the Google Quantum AI team, 50+ talks throughout the conference, and participation in conference organizing activities, special sessions and events. Attending APS 2024 in person? Come visit Google’s Quantum AI booth to learn more about the exciting work we’re doing to solve some of the field’s most interesting challenges. You can learn more about the latest cutting edge work we are presenting at the conference along with our schedule of booth events below (Googlers listed in bold). Organizing Committee Aaron Szasz Booth Activities This schedule is subject to change. Please visit the Google Quantum AI booth for more information. 
Crumble: A prototype interactive tool for visualizing QEC circuits Presenter: Matt McEwen Tue, Mar 5 | 11:00 AM CST Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms Presenter: Tanuj Khattar Tue, Mar 5 | 2:30 PM CST Qualtran: An open-source library for effective resource estimation of fault tolerant algorithms Presenter: Tanuj Khattar Thu, Mar 7 | 11:00 AM CST $5M XPRIZE / Google Quantum AI competition to accelerate quantum applications Q&A Presenter: Ryan Babbush Thu, Mar 7 | 11:00 AM CST Talks Monday Certifying highly-entangled states from few single-qubit measurements Presenter: Hsin-Yuan Huang Author: Hsin-Yuan Huang Session A45: New Frontiers in Machine Learning Quantum Physics Toward high-fidelity analog quantum simulation with superconducting qubits Presenter: Trond Andersen Authors: Trond I Andersen, Xiao Mi, Amir H Karamlou, Nikita Astrakhantsev, Andrey Klots, Julia Berndtsson, Andre Petukhov, Dmitry Abanin, Lev B Ioffe, Yu Chen, Vadim Smelyanskiy, Pedram Roushan Session A51: Applications on Noisy Quantum Hardware I Measuring circuit errors in context for surface code circuits Presenter: Dripto M Debroy Authors: Dripto M Debroy, Jonathan A Gross, Élie Genois, Zhang Jiang Session B50: Characterizing Noise with QCVV Techniques Quantum computation of stopping power for inertial fusion target design I: Physics overview and the limits of classical algorithms Presenter: Andrew D. Baczewski Authors: Nicholas C. Rubin, Dominic W. Berry, Alina Kononov, Fionn D. Malone, Tanuj Khattar, Alec White, Joonho Lee, Hartmut Neven, Ryan Babbush, Andrew D. Baczewski Session B51: Heterogeneous Design for Quantum Applications Link to Paper Quantum computation of stopping power for inertial fusion target design II: Physics overview and the limits of classical algorithms Presenter: Nicholas C. Rubin Authors: Nicholas C. Rubin, Dominic W. Berry, Alina Kononov, Fionn D. 
Malone, Tanuj Khattar, Alec White, Joonho Lee, Hartmut Neven, Ryan Babbush, Andrew D. Baczewski Session B51: Heterogeneous Design for Quantum Applications Link to Paper Calibrating Superconducting Qubits: From NISQ to Fault Tolerance Presenter: Sabrina S Hong Author: Sabrina S Hong Session B56: From NISQ to Fault Tolerance Measurement and feedforward induced entanglement negativity transition Presenter: Ramis Movassagh Authors: Alireza Seif, Yu-Xin Wang, Ramis Movassagh, Aashish A. Clerk Session B31: Measurement Induced Criticality in Many-Body Systems Link to Paper Effective quantum volume, fidelity and computational cost of noisy quantum processing experiments Presenter: Salvatore Mandra Authors: Kostyantyn Kechedzhi, Sergei V Isakov, Salvatore Mandra, Benjamin Villalonga, X. Mi, Sergio Boixo, Vadim Smelyanskiy Session B52: Quantum Algorithms and Complexity Link to Paper Accurate thermodynamic tables for solids using Machine Learning Interaction Potentials and Covariance of Atomic Positions Presenter: Mgcini K Phuthi Authors: Mgcini K Phuthi, Yang Huang, Michael Widom, Ekin D Cubuk, Venkat Viswanathan Session D60: Machine Learning of Molecules and Materials: Chemical Space and Dynamics Tuesday IN-Situ Pulse Envelope Characterization Technique (INSPECT) Presenter: Zhang Jiang Authors: Zhang Jiang, Jonathan A Gross, Élie Genois Session F50: Advanced Randomized Benchmarking and Gate Calibration Characterizing two-qubit gates with dynamical decoupling Presenter: Jonathan A Gross Authors: Jonathan A Gross, Zhang Jiang, Élie Genois, Dripto M Debroy, Ze-Pei Cian*, Wojciech Mruczkiewicz Session F50: Advanced Randomized Benchmarking and Gate Calibration Statistical physics of regression with quadratic models Presenter: Blake Bordelon Authors: Blake Bordelon, Cengiz Pehlevan, Yasaman Bahri Session EE01: V: Statistical and Nonlinear Physics II Improved state preparation for first-quantized simulation of electronic structure Presenter: William J Huggins Authors: William J 
Huggins, Oskar Leimkuhler, Torin F Stetina, Birgitta Whaley Session G51: Hamiltonian Simulation Controlling large superconducting quantum processors Presenter: Paul V. Klimov Authors: Paul V. Klimov, Andreas Bengtsson, Chris Quintana, Alexandre Bourassa, Sabrina Hong, Andrew Dunsworth, Kevin J. Satzinger, William P. Livingston, Volodymyr Sivak, Murphy Y. Niu, Trond I. Andersen, Yaxing Zhang, Desmond Chik, Zijun Chen, Charles Neill, Catherine Erickson, Alejandro Grajales Dau, Anthony Megrant, Pedram Roushan, Alexander N. Korotkov, Julian Kelly, Vadim Smelyanskiy, Yu Chen, Hartmut Neven Session G30: Commercial Applications of Quantum Computing Link to Paper Gaussian boson sampling: Determining quantum advantage Presenter: Peter D Drummond Authors: Peter D Drummond, Alex Dellios, Ned Goodman, Margaret D Reid, Ben Villalonga Session G50: Quantum Characterization, Verification, and Validation II Attention to complexity III: learning the complexity of random quantum circuit states Presenter: Hyejin Kim Authors: Hyejin Kim, Yiqing Zhou, Yichen Xu, Chao Wan, Jin Zhou, Yuri D Lensky, Jesse Hoke, Pedram Roushan, Kilian Q Weinberger, Eun-Ah Kim Session G50: Quantum Characterization, Verification, and Validation II Balanced coupling in superconducting circuits Presenter: Daniel T Sank Authors: Daniel T Sank, Sergei V Isakov, Mostafa Khezri, Juan Atalaya Session K48: Strongly Driven Superconducting Systems Resource estimation of Fault Tolerant algorithms using Qᴜᴀʟᴛʀᴀɴ Presenter: Tanuj Khattar Author: Tanuj Khattar, Matthew Harrigan, Fionn D. Malone, Nour Yosri, Nicholas C. 
Rubin Session K49: Algorithms and Implementations on Near-Term Quantum Computers Wednesday Discovering novel quantum dynamics with superconducting qubits Presenter: Pedram Roushan Author: Pedram Roushan Session M24: Analog Quantum Simulations Across Platforms Deciphering Tumor Heterogeneity in Triple-Negative Breast Cancer: The Crucial Role of Dynamic Cell-Cell and Cell-Matrix Interactions Presenter: Susan Leggett Authors: Susan Leggett, Ian Wong, Celeste Nelson, Molly Brennan, Mohak Patel, Christian Franck, Sophia Martinez, Joe Tien, Lena Gamboa, Thomas Valentin, Amanda Khoo, Evelyn K Williams Session M27: Mechanics of Cells and Tissues II Toward implementation of protected charge-parity qubits Presenter: Abigail Shearrow Authors: Abigail Shearrow, Matthew Snyder, Bradley G Cole, Kenneth R Dodge, Yebin Liu, Andrey Klots, Lev B Ioffe, Britton L Plourde, Robert McDermott Session N48: Unconventional Superconducting Qubits Electronic capacitance in tunnel junctions for protected charge-parity qubits Presenter: Bradley G Cole Authors: Bradley G Cole, Kenneth R Dodge, Yebin Liu, Abigail Shearrow, Matthew Snyder, Andrey Klots, Lev B Ioffe, Robert McDermott, B.L.T. Plourde Session N48: Unconventional Superconducting Qubits Overcoming leakage in quantum error correction Presenter: Kevin C. Miao Authors: Kevin C. Miao, Matt McEwen, Juan Atalaya, Dvir Kafri, Leonid P. Pryadko, Andreas Bengtsson, Alex Opremcak, Kevin J. Satzinger, Zijun Chen, Paul V. Klimov, Chris Quintana, Rajeev Acharya, Kyle Anderson, Markus Ansmann, Frank Arute, Kunal Arya, Abraham Asfaw, Joseph C. Bardin, Alexandre Bourassa, Jenna Bovaird, Leon Brill, Bob B. Buckley, David A. Buell, Tim Burger, Brian Burkett, Nicholas Bushnell, Juan Campero, Ben Chiaro, Roberto Collins, Paul Conner, Alexander L. Crook, Ben Curtin, Dripto M. Debroy, Sean Demura, Andrew Dunsworth, Catherine Erickson, Reza Fatemi, Vinicius S. Ferreira, Leslie Flores Burgos, Ebrahim Forati, Austin G. 
Fowler, Brooks Foxen, Gonzalo Garcia, William Giang, Craig Gidney, Marissa Giustina, Raja Gosula, Alejandro Grajales Dau, Jonathan A. Gross, Michael C. Hamilton, Sean D. Harrington, Paula Heu, Jeremy Hilton, Markus R. Hoffmann, Sabrina Hong, Trent Huang, Ashley Huff, Justin Iveland, Evan Jeffrey, Zhang Jiang, Cody Jones, Julian Kelly, Seon Kim, Fedor Kostritsa, John Mark Kreikebaum, David Landhuis, Pavel Laptev, Lily Laws, Kenny Lee, Brian J. Lester, Alexander T. Lill, Wayne Liu, Aditya Locharla, Erik Lucero, Steven Martin, Anthony Megrant, Xiao Mi, Shirin Montazeri, Alexis Morvan, Ofer Naaman, Matthew Neeley, Charles Neill, Ani Nersisyan, Michael Newman, Jiun How Ng, Anthony Nguyen, Murray Nguyen, Rebecca Potter, Charles Rocque, Pedram Roushan, Kannan Sankaragomathi, Christopher Schuster, Michael J. Shearn, Aaron Shorter, Noah Shutty, Vladimir Shvarts, Jindra Skruzny, W. Clarke Smith, George Sterling, Marco Szalay, Douglas Thor, Alfredo Torres, Theodore White, Bryan W. K. Woo, Z. Jamie Yao, Ping Yeh, Juhwan Yoo, Grayson Young, Adam Zalcman, Ningfeng Zhu, Nicholas Zobrist, Hartmut Neven, Vadim Smelyanskiy, Andre Petukhov, Alexander N. 
Korotkov, Daniel Sank, Yu Chen Session N51: Quantum Error Correction Code Performance and Implementation I Link to Paper Modeling the performance of the surface code with non-uniform error distribution: Part 1 Presenter: Yuri D Lensky Authors: Yuri D Lensky, Volodymyr Sivak, Kostyantyn Kechedzhi, Igor Aleiner Session N51: Quantum Error Correction Code Performance and Implementation I Modeling the performance of the surface code with non-uniform error distribution: Part 2 Presenter: Volodymyr Sivak Authors: Volodymyr Sivak, Michael Newman, Cody Jones, Henry Schurkus, Dvir Kafri, Yuri D Lensky, Paul Klimov, Kostyantyn Kechedzhi, Vadim Smelyanskiy Session N51: Quantum Error Correction Code Performance and Implementation I Highly optimized tensor network contractions for the simulation of classically challenging quantum computations Presenter: Benjamin Villalonga Author: Benjamin Villalonga Session Q51: Co-evolution of Quantum Classical Algorithms Teaching modern quantum computing concepts using hands-on open-source software at all levels Presenter: Abraham Asfaw Author: Abraham Asfaw Session Q61: Teaching Quantum Information at All Levels II Thursday New circuits and an open source decoder for the color code Presenter: Craig Gidney Authors: Craig Gidney, Cody Jones Session S51: Quantum Error Correction Code Performance and Implementation II Link to Paper Performing Hartree-Fock many-body physics calculations with large language models Presenter: Eun-Ah Kim Authors: Eun-Ah Kim, Haining Pan, Nayantara Mudur, William Taranto, Subhashini Venugopalan, Yasaman Bahri, Michael P Brenner Session S18: Data Science, AI and Machine Learning in Physics I New methods for reducing resource overhead in the surface code Presenter: Michael Newman Authors: Craig M Gidney, Michael Newman, Peter Brooks, Cody Jones Session S51: Quantum Error Correction Code Performance and Implementation II Link to Paper Challenges and opportunities for applying quantum computers to drug design Presenter: 
Raffaele Santagati Authors: Raffaele Santagati, Alan Aspuru-Guzik, Ryan Babbush, Matthias Degroote, Leticia Gonzalez, Elica Kyoseva, Nikolaj Moll, Markus Oppel, Robert M. Parrish, Nicholas C. Rubin, Michael Streif, Christofer S. Tautermann, Horst Weiss, Nathan Wiebe, Clemens Utschig-Utschig Session S49: Advances in Quantum Algorithms for Near-Term Applications Link to Paper Dispatches from Google's hunt for super-quadratic quantum advantage in new applications Presenter: Ryan Babbush Author: Ryan Babbush Session T45: Recent Advances in Quantum Algorithms Qubit as a reflectometer Presenter: Yaxing Zhang Authors: Yaxing Zhang, Benjamin Chiaro Session T48: Superconducting Fabrication, Packaging, & Validation Random-matrix theory of measurement-induced phase transitions in nonlocal Floquet quantum circuits Presenter: Aleksei Khindanov Authors: Aleksei Khindanov, Lara Faoro, Lev Ioffe, Igor Aleiner Session W14: Measurement-Induced Phase Transitions Continuum limit of finite density many-body ground states with MERA Presenter: Subhayan Sahu Authors: Subhayan Sahu, Guifré Vidal Session W58: Extreme-Scale Computational Science Discovery in Fluid Dynamics and Related Disciplines II Dynamics of magnetization at infinite temperature in a Heisenberg spin chain Presenter: Eliott Rosenberg Authors: Eliott Rosenberg, Trond Andersen, Rhine Samajdar, Andre Petukhov, Jesse Hoke*, Dmitry Abanin, Andreas Bengtsson, Ilya Drozdov, Catherine Erickson, Paul Klimov, Xiao Mi, Alexis Morvan, Matthew Neeley, Charles Neill, Rajeev Acharya, Richard Allen, Kyle Anderson, Markus Ansmann, Frank Arute, Kunal Arya, Abraham Asfaw, Juan Atalaya, Joseph Bardin, A. Bilmes, Gina Bortoli, Alexandre Bourassa, Jenna Bovaird, Leon Brill, Michael Broughton, Bob B. 
Buckley, David Buell, Tim Burger, Brian Burkett, Nicholas Bushnell, Juan Campero, Hung-Shen Chang, Zijun Chen, Benjamin Chiaro, Desmond Chik, Josh Cogan, Roberto Collins, Paul Conner, William Courtney, Alexander Crook, Ben Curtin, Dripto Debroy, Alexander Del Toro Barba, Sean Demura, Agustin Di Paolo, Andrew Dunsworth, Clint Earle, E. Farhi, Reza Fatemi, Vinicius Ferreira, Leslie Flores, Ebrahim Forati, Austin Fowler, Brooks Foxen, Gonzalo Garcia, Élie Genois, William Giang, Craig Gidney, Dar Gilboa, Marissa Giustina, Raja Gosula, Alejandro Grajales Dau, Jonathan Gross, Steve Habegger, Michael Hamilton, Monica Hansen, Matthew Harrigan, Sean Harrington, Paula Heu, Gordon Hill, Markus Hoffmann, Sabrina Hong, Trent Huang, Ashley Huff, William Huggins, Lev Ioffe, Sergei Isakov, Justin Iveland, Evan Jeffrey, Zhang Jiang, Cody Jones, Pavol Juhas, D. Kafri, Tanuj Khattar, Mostafa Khezri, Mária Kieferová, Seon Kim, Alexei Kitaev, Andrey Klots, Alexander Korotkov, Fedor Kostritsa, John Mark Kreikebaum, David Landhuis, Pavel Laptev, Kim Ming Lau, Lily Laws, Joonho Lee, Kenneth Lee, Yuri Lensky, Brian Lester, Alexander Lill, Wayne Liu, William P. Livingston, A. Locharla, Salvatore Mandrà, Orion Martin, Steven Martin, Jarrod McClean, Matthew McEwen, Seneca Meeks, Kevin Miao, Amanda Mieszala, Shirin Montazeri, Ramis Movassagh, Wojciech Mruczkiewicz, Ani Nersisyan, Michael Newman, Jiun How Ng, Anthony Nguyen, Murray Nguyen, M. Niu, Thomas O'Brien, Seun Omonije, Alex Opremcak, Rebecca Potter, Leonid Pryadko, Chris Quintana, David Rhodes, Charles Rocque, N. 
Rubin, Negar Saei, Daniel Sank, Kannan Sankaragomathi, Kevin Satzinger, Henry Schurkus, Christopher Schuster, Michael Shearn, Aaron Shorter, Noah Shutty, Vladimir Shvarts, Volodymyr Sivak, Jindra Skruzny, Clarke Smith, Rolando Somma, George Sterling, Doug Strain, Marco Szalay, Douglas Thor, Alfredo Torres, Guifre Vidal, Benjamin Villalonga, Catherine Vollgraff Heidweiller, Theodore White, Bryan Woo, Cheng Xing, Jamie Yao, Ping Yeh, Juhwan Yoo, Grayson Young, Adam Zalcman, Yaxing Zhang, Ningfeng Zhu, Nicholas Zobrist, Hartmut Neven, Ryan Babbush, Dave Bacon, Sergio Boixo, Jeremy Hilton, Erik Lucero, Anthony Megrant, Julian Kelly, Yu Chen, Vadim Smelyanskiy, Vedika Khemani, Sarang Gopalakrishnan, Tomaž Prosen, Pedram Roushan Session W50: Quantum Simulation of Many-Body Physics Link to Paper The fast multipole method on a quantum computer Presenter: Kianna Wan Authors: Kianna Wan, Dominic W Berry, Ryan Babbush Session W50: Quantum Simulation of Many-Body Physics Friday The quantum computing industry and protecting national security: what tools will work? Presenter: Kate Weber Author: Kate Weber Session Y43: Industry, Innovation, and National Security: Finding the Right Balance Novel charging effects in the fluxonium qubit Presenter: Agustin Di Paolo Authors: Agustin Di Paolo, Kyle Serniak, Andrew J Kerman, William D Oliver Session Y46: Fluxonium-Based Superconducting Qubits Microwave Engineering of Parametric Interactions in Superconducting Circuits Presenter: Ofer Naaman Author: Ofer Naaman Session Z46: Broadband Parametric Amplifiers and Circulators Linear spin wave theory of large magnetic unit cells using the Kernel Polynomial Method Presenter: Harry Lane Authors: Harry Lane, Hao Zhang, David A Dahlbom, Sam Quinn, Rolando D Somma, Martin P Mourigal, Cristian D Batista, Kipton Barros Session Z62: Cooperative Phenomena, Theory *Work done while at Google
Posted by Long Zhao, Senior Research Scientist, and Ting Liu, Senior Staff Software Engineer, Google Research An astounding number of videos are available on the Web, covering a variety of content from everyday moments people share to historical moments to scientific observations, each of which contains a unique record of the world. The right tools could help researchers analyze these videos, transforming how we understand the world around us. Videos offer dynamic visual content far more rich than static images, capturing movement, changes, and dynamic relationships between entities. Analyzing this complexity, along with the immense diversity of publicly available video data, demands models that go beyond traditional image understanding. Consequently, many of the approaches that best perform on video understanding still rely on specialized models tailor-made for particular tasks. Recently, there has been exciting progress in this area using video foundation models (ViFMs), such as VideoCLIP, InternVideo, VideoCoCa, and UMT. However, building a ViFM that handles the sheer diversity of video data remains a challenge. With the goal of building a single model for general-purpose video understanding, we introduce “VideoPrism: A Foundational Visual Encoder for Video Understanding”. VideoPrism is a ViFM designed to handle a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering (QA). We propose innovations in both the pre-training data as well as the modeling strategy. We pre-train VideoPrism on a massive and diverse dataset: 36 million high-quality video-text pairs and 582 million video clips with noisy or machine-generated parallel text. Our pre-training approach is designed for this hybrid data, to learn both from video-text pairs and the videos themselves. 
VideoPrism is incredibly easy to adapt to new video understanding challenges, and achieves state-of-the-art performance using a single frozen model. VideoPrism is a general-purpose video encoder that enables state-of-the-art results over a wide spectrum of video understanding tasks, including classification, localization, retrieval, captioning, and question answering, by producing video representations from a single frozen model. Pre-training data YT-Temporal-180M, InternVid, VideoCC, WTS-70M, etc. This includes 36 million carefully selected videos with high-quality captions, along with an additional 582 million clips with varying levels of noisy text (like auto-generated transcripts). To our knowledge, this is the largest and most diverse video training corpus of its kind. Statistics on the video-text pre-training data. The large variations of the CLIP similarity scores (the higher, the better) demonstrate the diverse caption quality of our pre-training data, which is a byproduct of the various ways used to harvest the text. Two-stage training vision transformer (ViT) with a factorized design that sequentially encodes spatial and temporal information following ViViT. Our training approach leverages both the high-quality video-text data and the video data with noisy text mentioned above. To start, we use contrastive learning (an approach that minimizes the distance between positive video-text pairs while maximizing the distance between negative video-text pairs) to teach our model to match videos with their own text descriptions, including imperfect ones. This builds a foundation for matching semantic language content to visual content. After video-text contrastive training, we leverage the collection of videos without text descriptions. Here, we build on the masked video modeling framework to predict masked patches in a video, with a few improvements. 
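The video-text contrastive objective described above, which pulls matching video-text pairs together and pushes mismatched pairs apart, can be sketched as a symmetric softmax loss over a batch of embeddings. This is a toy under stated assumptions: the symmetric cross-entropy form, the cosine similarity, and the `temperature` value are common choices assumed here, not details confirmed for VideoPrism, and the embeddings are hand-written vectors rather than encoder outputs.

```python
import math


def _normalize(vector):
    """Scale a vector to unit length so dot products are cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vector)) or 1.0
    return [x / norm for x in vector]


def _cross_entropy_diag(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        peak = max(row)  # subtract the max for numerical stability
        log_sum = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_sum - row[i]
    return total / len(logits)


def contrastive_loss(video_emb, text_emb, temperature=0.1):
    """Symmetric contrastive loss; pair i in video_emb matches pair i in text_emb."""
    videos = [_normalize(v) for v in video_emb]
    texts = [_normalize(t) for t in text_emb]
    logits = [
        [sum(a * b for a, b in zip(v, t)) / temperature for t in texts]
        for v in videos
    ]
    transposed = [list(col) for col in zip(*logits)]
    # Average the video-to-text and text-to-video directions.
    return 0.5 * (_cross_entropy_diag(logits) + _cross_entropy_diag(transposed))


videos = [[1.0, 0.0], [0.0, 1.0]]
aligned_loss = contrastive_loss(videos, [[1.0, 0.0], [0.0, 1.0]])
shuffled_loss = contrastive_loss(videos, [[0.0, 1.0], [1.0, 0.0]])
print(f"aligned: {aligned_loss:.4f}, shuffled: {shuffled_loss:.4f}")
```

The loss is low when each video is most similar to its own caption and high when captions are swapped, which is the signal that teaches the encoder to align visual and language content.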
We train the model to predict both the video-level global embedding and token-wise embeddings from the first-stage model to effectively leverage the knowledge acquired in that stage. We then randomly shuffle the predicted tokens to prevent the model from learning shortcuts. What is unique about VideoPrism’s setup is that we use two complementary pre-training signals: text descriptions and the visual content within a video. Text descriptions often focus on what things look like, while the video content provides information about movement and visual dynamics. This enables VideoPrism to excel in tasks that demand an understanding of both appearance and motion. Results VideoPrism compared to the previous best-performing FMs. Classification and localization VideoGLUE) covering classification and localization tasks. We find that (1) VideoPrism outperforms all of the other state-of-the-art FMs, and (2) no other single model consistently came in second place. This tells us that VideoPrism has learned to effectively pack a variety of video signals into one encoder — from semantics at different granularities to appearance and motion cues — and it works well across a variety of video sources. VideoPrism outperforms state-of-the-art approaches (including CLIP, VATT, InternVideo, and UMT) on the video understanding benchmark. In this plot, we show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. On Charades, ActivityNet, AVA, and AVA-K, we use mean average precision (mAP) as the evaluation metric. On the other datasets, we report top-1 accuracy. Combining with LLMs LiT) or a language decoder (such as PaLM-2), VideoPrism can be utilized for video-text retrieval, video captioning, and video QA tasks. We compare the combined models on a broad and challenging set of vision-language benchmarks. VideoPrism sets the new state of the art on most benchmarks. 
From the visual results, we find that VideoPrism is capable of understanding complex motions and appearances in videos (e.g., the model can recognize the different colors of spinning objects on the window in the visual examples below). These results demonstrate that VideoPrism is strongly compatible with language models. VideoPrism achieves competitive results compared with state-of-the-art approaches (including VideoCoCa, UMT and Flamingo) on multiple video-text retrieval (top) and video captioning and video QA (bottom) benchmarks. We also show the absolute score differences compared with the previous best model to highlight the relative improvements of VideoPrism. We report the Recall@1 on MSRVTT, VATEX, and ActivityNet, CIDEr score on MSRVTT-Cap, VATEX-Cap, and YouCook2, top-1 accuracy on MSRVTT-QA and MSVD-QA, and WUPS index on NExT-QA. We show qualitative results using VideoPrism with a text encoder for video-text retrieval (first row) and adapted to a language decoder for video QA (second and third rows). For video-text retrieval examples, the blue bars indicate the embedding similarities between the videos and the text queries. Scientific applications Fly vs. Fly, CalMS21, ChimpACT, and KABR. VideoPrism not only performs exceptionally well, but actually surpasses models designed specifically for those tasks. This suggests tools like VideoPrism have the potential to transform how scientists analyze video data across different fields. VideoPrism outperforms the domain experts on various scientific benchmarks. We show the absolute score differences to highlight the relative improvements of VideoPrism. We report mean average precision (mAP) for all datasets, except for KABR which uses class-averaged top-1 accuracy. Conclusion AI Principles. We hope VideoPrism paves the way for future breakthroughs at the intersection of AI and video analysis, helping to realize the potential of ViFMs across domains such as scientific discovery, education, and healthcare. 
Acknowledgements This blog post is made on behalf of all the VideoPrism authors: Long Zhao, Nitesh B. Gundavarapu, Liangzhe Yuan, Hao Zhou, Shen Yan, Jennifer J. Sun, Luke Friedman, Rui Qian, Tobias Weyand, Yue Zhao, Rachel Hornung, Florian Schroff, Ming-Hsuan Yang, David A. Ross, Huisheng Wang, Hartwig Adam, Mikhail Sirotenko, Ting Liu, and Boqing Gong. We sincerely thank David Hendon for their product management efforts, and Alex Siegman, Ramya Ganeshan, and Victor Gomes for their program and resource management efforts. We also thank Hassan Akbari, Sherry Ben, Yoni Ben-Meshulam, Chun-Te Chu, Sam Clearwater, Yin Cui, Ilya Figotin, Anja Hauth, Sergey Ioffe, Xuhui Jia, Yeqing Li, Lu Jiang, Zu Kim, Dan Kondratyuk, Bill Mark, Arsha Nagrani, Caroline Pantofaru, Sushant Prakash, Cordelia Schmid, Bryan Seybold, Mojtaba Seyedhosseini, Amanda Sadler, Rif A. Saurous, Rachel Stigler, Paul Voigtlaender, Pingmei Xu, Chaochao Yan, Xuan Yang, and Yukun Zhu for the discussions, support, and feedback that greatly contributed to this work. We are grateful to Jay Yagnik, Rahul Sukthankar, and Tomas Izo for their enthusiastic support for this project. Lastly, we thank Tom Small, Jennifer J. Sun, Hao Zhou, Nitesh B. Gundavarapu, Luke Friedman, and Mikhail Sirotenko for the tremendous help with making this blog post.
Posted by Zheng Xu, Research Scientist, and Yanxiang Zhang, Software Engineer, Google Language models (LMs) trained to predict the next word given input text are the key technology for many applications [1, 2]. In Gboard, LMs are used to improve users’ typing experience by supporting features like next word prediction (NWP), Smart Compose, smart completion and suggestion, slide to type, and proofread. Deploying models on users’ devices rather than enterprise servers has advantages like lower latency and better privacy for model usage. While training on-device models directly from user data effectively improves the utility performance for applications such as NWP and smart text selection, protecting the privacy of user data for model training is important. Gboard features powered by on-device language models. federated learning (FL) in 2017 and formal differential privacy (DP) guarantees in 2022. FL enables mobile phones to collaboratively learn a model while keeping all the training data on device, and DP provides a quantifiable measure of data anonymization. Formally, DP is often characterized by (ε, δ) with smaller values representing stronger guarantees. Machine learning (ML) models are considered to have reasonable DP guarantees for ε=10 and strong DP guarantees for ε=1 when δ is small. As of today, all NWP neural network LMs in Gboard are trained with FL with formal DP guarantees, and all future launches of Gboard LMs trained on user data require DP. These 30+ Gboard on-device LMs are launched in 7+ languages and 15+ countries, and satisfy (ε, δ)-DP guarantees with a small δ of 10^-10 and ε between 0.994 and 13.69. To the best of our knowledge, this is the largest known deployment of user-level DP in production at Google or anywhere, and the first time a strong DP guarantee of ε
Posted by Nishant Jain, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research The constantly changing nature of the world around us poses a significant challenge for the development of AI models. Often, models are trained on longitudinal data with the hope that the training data used will accurately represent inputs the model may receive in the future. More generally, the default assumption that all training data are equally relevant often breaks in practice. For example, the figure below shows images from the CLEAR nonstationary learning benchmark, and it illustrates how visual features of objects evolve significantly over a 10-year span (a phenomenon we refer to as slow concept drift), posing a challenge for object categorization models. Sample images from the CLEAR benchmark. (Adapted from Lin et al.) online and continual learning, repeatedly update a model with small amounts of recent data in order to keep it current. This implicitly prioritizes recent data, as the learnings from past data are gradually erased by subsequent updates. However, in the real world, different kinds of information lose relevance at different rates, so there are two key issues: 1) By design, these approaches focus exclusively on the most recent data and lose any signal from older data that is erased. 2) Contributions from data instances decay uniformly over time irrespective of the contents of the data. In our recent work, “Instance-Conditional Timescales of Decay for Non-Stationary Learning”, we propose to assign each instance an importance score during training in order to maximize model performance on future data. To accomplish this, we employ an auxiliary model that produces these scores using the training instance as well as its age. This model is jointly learned with the primary model. We address both the above challenges and achieve significant gains over other robust learning methods on a range of benchmark datasets for nonstationary learning. 
For instance, on a recent large-scale benchmark for nonstationary learning (~39M photos over a 10-year period), we show up to 15% relative accuracy gains through learned reweighting of training data. The challenge of concept drift for supervised learning recent photo categorization task, comprising roughly 39M photographs sourced from social media websites over a 10-year period. We compared offline training, which iterated over all the training data multiple times in random order, and continual training, which iterated multiple times over each month of data in sequential (temporal) order. We measured model accuracy both during the training period and during a subsequent period where both models were frozen, i.e., not updated further on new data (shown below). At the end of the training period (left panel, x-axis = 0), both approaches have seen the same amount of data, but show a large performance gap. This is due to catastrophic forgetting, a problem in continual learning where a model’s knowledge of data from early on in the training sequence is diminished in an uncontrolled manner. On the other hand, forgetting has its advantages: over the test period (shown on the right), the continually trained model degrades much less rapidly than the offline model because it is less dependent on older data. The decay of both models’ accuracy in the test period is confirmation that the data is indeed evolving over time, and both models become increasingly less relevant. Comparing offline and continually trained models on the photo classification task. Time-sensitive reweighting of training data M, given some training data collected over time. We propose to also train a helper model that assigns a weight to each point based on its contents and age. This weight scales the contribution from that data point in the training objective for M. The objective of the weights is to improve the performance of M on future data. 
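A minimal sketch of how a per-instance weight that depends on both content and age can scale the training objective. The specific weight form used here (a softmax-weighted mix of exponential decays over a few fixed timescales), the timescale values, and all function names are assumptions for illustration rather than the paper's exact parameterization. In the real method the assignment logits would be produced by the learned helper model from the instance itself.

```python
import math

# Hypothetical decay timescales, e.g. measured in months.
TIMESCALES = (1.0, 12.0, 120.0)


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def instance_weight(assignment_logits, age, timescales=TIMESCALES):
    """Weight of one training instance given its content-based logits and age.

    The logits softly assign the instance to decay timescales; each
    timescale contributes an exponential decay in the instance's age.
    """
    assignment = softmax(assignment_logits)
    return sum(p * math.exp(-age / tau)
               for p, tau in zip(assignment, timescales))


def weighted_loss(losses, weights):
    """Training objective for M: per-instance losses scaled by their weights.

    Normalizing by the total weight is a design choice of this sketch.
    """
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)
```

With this form, two instances of the same age can receive very different weights: one assigned to a slow timescale stays influential, while one assigned to a fast timescale fades quickly.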
In our work, we describe how the helper model can be meta-learned, i.e., learned alongside M in a manner that helps the learning of the model M itself. A key design choice of the helper model is that we separated out instance- and age-related contributions in a factored manner. Specifically, we set the weight by combining contributions from multiple different fixed timescales of decay, and learn an approximate “assignment” of a given instance to its most suited timescales. We find in our experiments that this form of the helper model outperforms many other alternatives we considered, ranging from unconstrained joint functions to a single timescale of decay (exponential or linear), due to its combination of simplicity and expressivity. Full details may be found in the paper. Instance weight scoring CLEAR object recognition challenge; older-looking objects are correspondingly down-weighted. On closer examination (bottom figure below, gradient-based feature importance assessment), we see that the helper model focuses on the primary object within the image, as opposed to, e.g., background features that may spuriously be correlated with instance age. Sample images from the CLEAR benchmark (camera & computer categories) assigned the highest and lowest weights respectively by our helper model. Feature importance analysis of our helper model on sample images from the CLEAR benchmark. Results Gains on large-scale data photo categorization task (PCAT) on the YFCC100M dataset discussed earlier, using the first five years of data for training and the next five years as test data. Our method (shown in red below) improves substantially over the no-reweighting baseline (black) as well as many other robust learning techniques. Interestingly, our method deliberately trades off accuracy on the distant past (training data unlikely to reoccur in the future) in exchange for marked improvements in the test period. 
Also, as desired, our method degrades less than other baselines in the test period. Comparison of our method and relevant baselines on the PCAT dataset. Broad applicability 1, 2, 3, 4 for details) that spans data sources and modalities (photos, satellite images, social media text, medical records, sensor readings, tabular data) and sizes (ranging from 10k to 39M instances). We report significant gains in the test period when compared to the nearest published benchmark method for each dataset (shown below). Note that the previous best-known method may be different for each dataset. These results showcase the broad applicability of our approach. Performance gain of our method on a variety of tasks studying natural concept drift. Our reported gains are over the previous best-known method for each dataset. Extensions to continual learning within the context of each bucket of data being used to sequentially update the model. This proposal still retains some limitations of continual learning, e.g., model updates are performed only on most-recent data, and all optimization decisions (including our reweighting) are only made over that data. Nevertheless, our approach consistently beats regular continual learning as well as a wide range of other continual learning algorithms on the photo categorization benchmark (see below). Since our approach is complementary to the ideas in many baselines compared here, we anticipate even larger gains when combined with them. Results of our method adapted to continual learning, compared to the latest baselines. Conclusion Acknowledgements We thank Mike Mozer for many interesting discussions in the early phase of this work, as well as very helpful advice and feedback during its development.
Posted by Mónica Ribero Díaz, Research Scientist, Google Research Differential privacy (DP) is a property of randomized mechanisms that limit the influence of any individual user’s information while processing and analyzing data. DP offers a robust solution to address growing concerns about data protection, enabling technologies across industries and government applications (e.g., the US census) without compromising individual user identities. As its adoption increases, it’s important to identify the potential risks of developing mechanisms with faulty implementations. Researchers have recently found errors in the mathematical proofs of private mechanisms and in their implementations. For example, researchers compared six sparse vector technique (SVT) variations and found that only two of the six actually met the asserted privacy guarantee. Even when mathematical proofs are correct, the code implementing the mechanism is vulnerable to human error. However, practical and efficient DP auditing is challenging primarily due to the inherent randomness of the mechanisms and the probabilistic nature of the tested guarantees. In addition, a range of guarantee types exists (e.g., pure DP, approximate DP, Rényi DP, and concentrated DP), and this diversity contributes to the complexity of formulating the auditing problem. Further, debugging mathematical proofs and code bases is an intractable task given the volume of proposed mechanisms. While ad hoc testing techniques exist under specific assumptions of mechanisms, few efforts have been made to develop an extensible tool for testing DP mechanisms. To that end, in “DP-Auditorium: A Large Scale Library for Auditing Differential Privacy”, we introduce an open source library for auditing DP guarantees with only black-box access to a mechanism (i.e., without any knowledge of the mechanism’s internal properties). 
DP-Auditorium is implemented in Python and provides a flexible interface that allows contributions to continuously improve its testing capabilities. We also introduce new testing algorithms that perform divergence optimization over function spaces for Rényi DP, pure DP, and approximate DP. We demonstrate that DP-Auditorium can efficiently identify DP guarantee violations, and suggest which tests are most suitable for detecting particular bugs under various privacy guarantees. DP guarantees M (D)) that satisfies a mathematical property ensuring the privacy of user data. A DP guarantee is thus tightly related to properties between pairs of probability distributions. A mechanism is differentially private if the probability distributions determined by M on dataset D and a neighboring dataset D’, which differ by only one record, are indistinguishable under a given divergence metric. For example, the classical approximate DP definition states that a mechanism is approximately DP with parameters (ε, δ) if the hockey-stick divergence of order e^ε between M(D) and M(D’) is at most δ. Pure DP is a special instance of approximate DP where δ = 0. Finally, a mechanism is considered Rényi DP with parameters (𝛼, ε) if the Rényi divergence of order 𝛼 between M(D) and M(D’) is at most ε (where ε is a small positive value). In these three definitions, ε is not interchangeable but intuitively conveys the same concept; larger values of ε imply larger divergences between the two distributions or less privacy, since the two distributions are easier to distinguish. DP-Auditorium gradient descent mechanism variants. Property testers determine if evidence exists to reject the hypothesis that a given divergence between two probability distributions, P and Q, is bounded by a prespecified budget determined by the DP guarantee being tested. They compute a lower bound from samples from P and Q, rejecting the property if the lower bound value exceeds the expected divergence. 
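To make the approximate DP definition above concrete, here is a small sketch that computes the hockey-stick divergence exactly for discrete distributions and checks it against δ. It uses binary randomized response as a toy mechanism; the function names are illustrative and not part of DP-Auditorium, and the check covers only one pair of neighboring inputs rather than all of them.

```python
import math


def hockey_stick_divergence(p, q, epsilon):
    """Hockey-stick divergence of order e^epsilon between discrete p and q.

    p and q map each outcome to its probability. The divergence is the
    sum over outcomes of max(p(x) - e^epsilon * q(x), 0).
    """
    order = math.exp(epsilon)
    outcomes = set(p) | set(q)
    return sum(max(p.get(x, 0.0) - order * q.get(x, 0.0), 0.0)
               for x in outcomes)


def randomized_response(bit, epsilon):
    """Output distribution of binary randomized response applied to one bit."""
    keep = math.exp(epsilon) / (1.0 + math.exp(epsilon))
    return {bit: keep, 1 - bit: 1.0 - keep}


def satisfies_approximate_dp(p, q, epsilon, delta, tol=1e-12):
    """Check the (epsilon, delta) bound for one neighboring pair, both ways."""
    worst = max(hockey_stick_divergence(p, q, epsilon),
                hockey_stick_divergence(q, p, epsilon))
    return worst <= delta + tol
```

Randomized response built for ε = 1 passes the pure DP check (δ = 0) at ε = 1 but fails it when a stronger guarantee such as ε = 0.5 is claimed, since the divergence at the smaller order is strictly positive.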
No guarantees are provided if the result is indeed bounded. To test for a range of privacy guarantees, DP-Auditorium introduces three novel testers: (1) HockeyStickPropertyTester, (2) RényiPropertyTester, and (3) MMDPropertyTester. Unlike other approaches, these testers don’t depend on explicit histogram approximations of the tested distributions. They rely on variational representations of the hockey-stick divergence, Rényi divergence, and maximum mean discrepancy (MMD) that enable the estimation of divergences through optimization over function spaces. As a baseline, we implement HistogramPropertyTester, a commonly used approximate DP tester. While our three testers follow a similar approach, for brevity, we focus on the HockeyStickPropertyTester in this post. Given two neighboring datasets, D and D’, the HockeyStickPropertyTester finds a lower bound, δ̂, for the hockey-stick divergence between M(D) and M(D’) that holds with high probability. Hockey-stick divergence enforces that the two distributions M(D) and M(D’) are close under an approximate DP guarantee. Therefore, if a privacy guarantee claims that the hockey-stick divergence is at most δ, and δ̂ > δ, then with high probability the divergence is higher than what was promised on D and D’ and the mechanism cannot satisfy the given approximate DP guarantee. The lower bound δ̂ is computed as an empirical and tractable counterpart of a variational formulation of the hockey-stick divergence (see the paper for more details). The accuracy of δ̂ increases with the number of samples drawn from the mechanism, but decreases as the variational formulation is simplified. We balance these factors in order to ensure that δ̂ is both accurate and easy to compute. Dataset finders use black-box optimization to find datasets D and D’ that maximize δ̂, a lower bound on the divergence value δ. 
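The tester logic described above can be sketched in a few lines. This is not the HockeyStickPropertyTester implementation: the function class is restricted to one-dimensional threshold indicators, the confidence slack is a plain Hoeffding bound, and the helper names are made up for this example. It relies on the fact that, for any function f with values in [0, 1], E_P[f] − e^ε E_Q[f] never exceeds the hockey-stick divergence of order e^ε, so a held-out estimate of that quantity minus a confidence slack is a high-probability lower bound.

```python
import math
import random


def laplace_samples(loc, scale, n, rng):
    """Draw n Laplace samples (difference of two i.i.d. exponentials)."""
    return [loc + rng.expovariate(1.0 / scale) - rng.expovariate(1.0 / scale)
            for _ in range(n)]


def objective(p_samples, q_samples, threshold, below, order):
    """Empirical E_P[f] - order * E_Q[f] for a threshold indicator f.

    f(y) = 1 if y < threshold when below is True, else 1 if y >= threshold.
    """
    def f(y):
        return 1.0 if (y < threshold) == below else 0.0
    mean_p = sum(f(y) for y in p_samples) / len(p_samples)
    mean_q = sum(f(y) for y in q_samples) / len(q_samples)
    return mean_p - order * mean_q


def hockey_stick_lower_bound(p_samples, q_samples, epsilon, failure_prob=0.05):
    """Lower bound on the divergence that holds with prob. >= 1 - failure_prob."""
    order = math.exp(epsilon)
    half_p, half_q = len(p_samples) // 2, len(q_samples) // 2
    fit_p, eval_p = p_samples[:half_p], p_samples[half_p:]
    fit_q, eval_q = q_samples[:half_q], q_samples[half_q:]
    # Choose the best threshold function on the fitting half only.
    pooled = sorted(fit_p + fit_q)
    candidates = pooled[::max(1, len(pooled) // 50)]
    best = max(((t, below) for t in candidates for below in (True, False)),
               key=lambda c: objective(fit_p, fit_q, c[0], c[1], order))
    # Evaluate the chosen function on held-out samples.
    estimate = objective(eval_p, eval_q, best[0], best[1], order)
    # Hoeffding slack covering both empirical means via a union bound.
    n = min(len(eval_p), len(eval_q))
    slack = (1.0 + order) * math.sqrt(math.log(4.0 / failure_prob) / (2.0 * n))
    return estimate - slack
```

A mechanism is flagged when the returned value exceeds the claimed δ. The richer the function class, the tighter the bound can become, at the price of a harder optimization problem, which mirrors the trade-off noted above.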
Note that black-box optimization techniques are specifically designed for settings where deriving gradients for an objective function may be impractical or even impossible. These optimization techniques oscillate between exploration and exploitation phases to estimate the shape of the objective function and predict areas where the objective can have optimal values. In contrast, a full exploration algorithm, such as the grid search method, searches over the full space of neighboring datasets D and D’. DP-Auditorium implements different dataset finders through the open-source black-box optimization library Vizier. Running existing components on a new mechanism only requires defining the mechanism as a Python function that takes an array of data D and a desired number of samples n to be output by the mechanism computed on D. In addition, we provide flexible wrappers for testers and dataset finders that allow practitioners to implement their own testing and dataset search algorithms. Key results ε, and report the number of times each tester identifies privacy bugs. While no tester consistently outperforms the others, we identify bugs that would be missed by previous techniques (HistogramPropertyTester). Note that the HistogramPropertyTester is not applicable to SVT mechanisms. Number of times each property tester finds the privacy violation for the tested non-private mechanisms. NonDPLaplaceMean and NonDPGaussianMean mechanisms are faulty implementations of the Laplace and Gaussian mechanisms for computing the mean. DP gradient descent algorithm (DP-GD) in TensorFlow that computes gradients of the loss function on private data. To preserve privacy, DP-GD employs a clipping mechanism to bound the ℓ2-norm of the gradients by a value G, followed by the addition of Gaussian noise. This implementation incorrectly assumes that the noise added has a scale of G, while in reality, the scale is sG, where s is a positive scalar. 
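As an illustration of the two ideas above, the sketch below writes a clip-then-add-noise step in the "data in, n samples out" shape described for mechanisms. It is a simplified stand-in, not the TensorFlow DP-GD implementation and not DP-Auditorium's interface: the rows of `data` play the role of per-example gradients, and the parameter names `clip_norm` (G) and `s` are chosen for this example.

```python
import math
import random


def clip_l2(vector, clip_norm):
    """Rescale a vector so that its l2-norm is at most clip_norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    factor = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [x * factor for x in vector]


def noisy_clipped_sum(data, n, clip_norm=1.0, s=1.0, seed=None):
    """Return n independent outputs of a clipped, noised gradient sum.

    data: list of per-example gradient vectors (stand-ins for real gradients).
    Each vector is clipped to l2-norm clip_norm (G), the clipped vectors are
    summed, and Gaussian noise with standard deviation s * clip_norm is added
    to every coordinate. A privacy analysis that assumes s = 1 is only valid
    when the noise really has that scale.
    """
    rng = random.Random(seed)
    clipped = [clip_l2(row, clip_norm) for row in data]
    total = [sum(column) for column in zip(*clipped)]
    sigma = s * clip_norm
    return [[t + rng.gauss(0.0, sigma) for t in total] for _ in range(n)]
```

Drawing samples from such a function on two neighboring datasets, for example `noisy_clipped_sum(d, 1000, s=0.5)` and `noisy_clipped_sum(d_prime, 1000, s=0.5)`, produces the kind of input a property tester compares against the claimed guarantee.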
This discrepancy leads to an approximate DP guarantee that holds only for values of s greater than or equal to 1. We evaluate the effectiveness of property testers in detecting this bug and show that HockeyStickPropertyTester and RényiPropertyTester exhibit superior performance in identifying privacy violations, outperforming MMDPropertyTester and HistogramPropertyTester. Notably, these testers detect the bug even for values of s as high as 0.6. It is worth highlighting that s = 0.5 corresponds to a common error in literature that involves missing a factor of two when accounting for the privacy budget ε. DP-Auditorium successfully captures this bug as shown below. For more details see section 5.6 here. Estimated divergences and test thresholds for different values of s when testing DP-GD with the HistogramPropertyTester (left) and the HockeyStickPropertyTester (right). Estimated divergences and test thresholds for different values of s when testing DP-GD with the RényiPropertyTester (left) and the MMDPropertyTester (right) paper. Conclusion open sourcing DP-Auditorium, we aim to establish a standard for end-to-end testing of new differentially private algorithms. Acknowledgements The work described here was done jointly with Andrés Muñoz Medina, William Kong and Umar Syed. We thank Chris Dibak and Vadym Doroshenko for helpful engineering support and interface suggestions for our library.
Posted by Dustin Zelle, Software Engineer, Google Research, and Arno Eigenwillig, Software Engineer, CoreML Objects and their relationships are ubiquitous in the world around us, and relationships can be as important to understanding an object as its own attributes viewed in isolation — take for example transportation networks, production networks, knowledge graphs, or social networks. Discrete mathematics and computer science have a long history of formalizing such networks as graphs, consisting of nodes connected by edges in various irregular ways. Yet most machine learning (ML) algorithms allow only for regular and uniform relations between input objects, such as a grid of pixels, a sequence of words, or no relation at all. Graph neural networks, or GNNs for short, have emerged as a powerful technique to leverage both the graph’s connectivity (as in the older algorithms DeepWalk and Node2Vec) and the input features on the various nodes and edges. GNNs can make predictions for graphs as a whole (Does this molecule react in a certain way?), for individual nodes (What’s the topic of this document, given its citations?) or for potential edges (Is this product likely to be purchased together with that product?). Apart from making predictions about graphs, GNNs are a powerful tool used to bridge the chasm to more typical neural network use cases. They encode a graph's discrete, relational information in a continuous way so that it can be included naturally in another deep learning system. We are excited to announce the release of TensorFlow GNN 1.0 (TF-GNN), a production-tested library for building GNNs at large scales. It supports both modeling and training in TensorFlow as well as the extraction of input graphs from huge data stores. TF-GNN is built from the ground up for heterogeneous graphs, where types of objects and relations are represented by distinct sets of nodes and edges. 
Real-world objects and their relations occur in distinct types, and TF-GNN's heterogeneous focus makes it natural to represent them. Inside TensorFlow, such graphs are represented by objects of type tfgnn.GraphTensor. This is a composite tensor type (a collection of tensors in one Python class) accepted as a first-class citizen in tf.data.Dataset, tf.function, etc. It stores both the graph structure and its features attached to nodes, edges and the graph as a whole. Trainable transformations of GraphTensors can be defined as Layers objects in the high-level Keras API, or directly using the tfgnn.GraphTensor primitive. GNNs: Making predictions for an object in context Pictured, the process of subgraph sampling where small, tractable subgraphs are sampled from a larger graph to create input examples for GNN training. this one), for efficient sampling of a small dataset stored in the main memory of a single training host, or distributed by Apache Beam for huge datasets stored on a network filesystem (up to hundreds of millions of nodes and billions of edges). For details, please refer to our user guides for in-memory and beam-based sampling, respectively. On those same sampled subgraphs, the GNN’s task is to compute a hidden (or latent) state at the root node; the hidden state aggregates and encodes the relevant information of the root node's neighborhood. One classical approach is message-passing neural networks. In each round of message passing, nodes receive messages from their neighbors along incoming edges and update their own hidden state from them. After n rounds, the hidden state of the root node reflects the aggregate information from all nodes within n edges (pictured below for n = 2). The messages and the new hidden states are computed by hidden layers of the neural network. 
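The message-passing rounds described above can be illustrated without any library. In this sketch the learned layers are replaced by a plain sum, so it shows only how information flows along edges, not how TF-GNN computes messages; the function and variable names are illustrative.

```python
def message_passing(states, edges, rounds):
    """Run simple sum-based message passing on a directed graph.

    states: dict mapping node -> list of floats (the hidden state).
    edges:  list of (source, target) pairs; messages flow source -> target.
    In each round every node adds the states of its incoming neighbors to
    its own state. A trained GNN would apply learned layers here instead.
    """
    for _ in range(rounds):
        dim = len(next(iter(states.values())))
        incoming = {node: [0.0] * dim for node in states}
        for source, target in edges:
            for k, value in enumerate(states[source]):
                incoming[target][k] += value
        states = {node: [own + msg
                         for own, msg in zip(states[node], incoming[node])]
                  for node in states}
    return states


# A chain 3 -> 2 -> 1 -> 0 with node 0 as the root and one-hot initial states.
chain_edges = [(1, 0), (2, 1), (3, 2)]
initial = {i: [1.0 if k == i else 0.0 for k in range(4)] for i in range(4)}
root_after_two_rounds = message_passing(initial, chain_edges, rounds=2)[0]
```

Because the initial states are one-hot, the nonzero entries of `root_after_two_rounds` mark exactly the nodes within two edges of the root; node 3 only shows up after a third round.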
In a heterogeneous graph, it often makes sense to use separately trained hidden layers for the different types of nodes and edges. Pictured, a simple message-passing neural network where, at each step, the node state is propagated from outer to inner nodes where it is pooled to compute new node states. Once the root node is reached, a final prediction can be made. loss (to measure the prediction error), and updating model weights by backpropagation, as usual in any neural network training. Beyond supervised training (i.e., minimizing a loss defined by labels), GNNs can also be trained in an unsupervised way (i.e., without labels). This lets us compute a continuous representation (or embedding) of the discrete graph structure of nodes and their features. These representations are then typically utilized in other ML systems. In this way, the discrete, relational information encoded by a graph can be included in more typical neural network use cases. TF-GNN supports a fine-grained specification of unsupervised objectives for heterogeneous graphs. Building GNN architectures GraphNets. This can, but need not, be done with Keras as a modeling framework on top of core TensorFlow. For more details, and intermediate levels of modeling, see the TF-GNN user guide and model collection. Training orchestration TF-GNN Runner also provides a succinct way to orchestrate the training of Keras models in the common cases. A simple invocation may look like this: tfgnn.GraphTensor padding for fixed shapes on Cloud TPUs. Beyond training on a single task (as shown above), it supports joint training on multiple (two or more) tasks in concert. For example, unsupervised tasks can be mixed with supervised ones to inform a final continuous representation (or embedding) with application-specific inductive biases. Callers only need to substitute the task argument with a mapping of tasks: integrated gradients for use in model attribution. 
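For readers unfamiliar with integrated gradients, here is a generic sketch of the method on a toy function with plain Python lists. It is not the TF-GNN Runner's implementation and does not operate on a GraphTensor; the toy model, its hand-written gradient, and the midpoint-rule approximation are all choices made for this example.

```python
def integrated_gradients(grad_fn, inputs, baseline, steps=50):
    """Approximate integrated gradients along a straight path.

    grad_fn:  function returning the gradient of the model at a point.
    inputs:   the point being explained.
    baseline: the reference point (often all zeros).
    The path integral of the gradient is approximated with the midpoint rule.
    """
    averaged = [0.0] * len(inputs)
    for step in range(steps):
        alpha = (step + 0.5) / steps
        point = [b + alpha * (x - b) for x, b in zip(inputs, baseline)]
        for i, g in enumerate(grad_fn(point)):
            averaged[i] += g / steps
    # Scale the averaged gradient by the input's distance from the baseline.
    return [(x - b) * a for x, b, a in zip(inputs, baseline, averaged)]


def toy_model(x):
    """A small differentiable function standing in for a trained model."""
    return x[0] * x[1] + x[2] ** 2


def toy_model_grad(x):
    """Hand-written gradient of toy_model."""
    return [x[1], x[0], 2.0 * x[2]]


attributions = integrated_gradients(toy_model_grad, [1.0, 2.0, 3.0],
                                    [0.0, 0.0, 0.0])
```

The attributions sum to the difference between the model output at the input and at the baseline, which is a handy sanity check when inspecting which features a model relies on most.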
Integrated gradients output is a GraphTensor with the same connectivity as the observed GraphTensor but its features replaced with gradient values where larger values contribute more than smaller values in the GNN prediction. Users can inspect gradient values to see which features their GNN uses the most. Conclusion Colab demo with the popular OGBN-MAG benchmark (in your browser, no installation required), browse the rest of our user guides and Colabs, or take a look at our paper. Acknowledgements The TF-GNN release 1.0 was developed by a collaboration between Google Research: Sami Abu-El-Haija, Neslihan Bulut, Bahar Fatemi, Johannes Gasteiger, Pedro Gonnet, Jonathan Halcrow, Liangze Jiang, Silvio Lattanzi, Brandon Mayer, Vahab Mirrokni, Bryan Perozzi, Anton Tsitsulin, Dustin Zelle, Google Core ML: Arno Eigenwillig, Oleksandr Ferludin, Parth Kothari, Mihir Paradkar, Jan Pfeifer, Rachael Tamakloe, and Google DeepMind: Alvaro Sanchez-Gonzalez and Lisa Wang.
Posted by Rajat Sen and Yichen Zhou, Google Research Time-series forecasting is ubiquitous in various domains, such as retail, finance, manufacturing, healthcare and natural sciences. In retail use cases, for example, it has been observed that improving demand forecasting accuracy can meaningfully reduce inventory costs and increase revenue. Deep learning (DL) models have emerged as a popular approach for forecasting rich, multivariate, time-series data because they have proven to perform well in a variety of settings (e.g., DL models performed well in the M5 competition). At the same time, there has been rapid progress in large foundation language models used for natural language processing (NLP) tasks, such as translation, retrieval-augmented generation, and code completion. These models are trained on massive amounts of textual data derived from a variety of sources like common crawl and open-source code that allows them to identify patterns in languages. This makes them very powerful zero-shot tools; for instance, when paired with retrieval, they can answer questions about and summarize current events. Despite DL-based forecasters largely outperforming traditional methods and progress being made in reducing training and inference costs, they face challenges: most DL architectures require long and involved training and validation cycles before a customer can test the model on a new time-series. A foundation model for time-series forecasting, in contrast, can provide decent out-of-the-box forecasts on unseen time-series data with no additional training, enabling users to focus on refining forecasts for the actual downstream task like retail demand planning. To that end, in “A decoder-only foundation model for time-series forecasting”, we introduce TimesFM, a single forecasting model pre-trained on a large time-series corpus of 100 billion real world time-points. 
Compared to the latest large language models (LLMs), TimesFM is much smaller (200M parameters), yet we show that even at such scales, its zero-shot performance on a variety of unseen datasets of different domains and temporal granularities comes close to the state-of-the-art supervised approaches trained explicitly on these datasets. Later this year we plan to make this model available for external customers in Google Cloud Vertex AI.
A decoder-only foundation model for time-series forecasting
LLMs are usually trained in a decoder-only fashion that involves three steps. First, text is broken down into subwords called tokens. Then, the tokens are fed into stacked causal transformer layers that produce an output corresponding to each input token (the layers cannot attend to future tokens). Finally, the output corresponding to the i-th token summarizes all the information from previous tokens and predicts the (i+1)-th token. During inference, the LLM generates the output one token at a time. For example, when prompted with “What is the capital of France?”, it might generate the token “The”, then condition on “What is the capital of France? The” to generate the next token “capital” and so on until it generates the complete answer: “The capital of France is Paris”. A foundation model for time-series forecasting should adapt to variable context (what we observe) and horizon (what we query the model to forecast) lengths, while having enough capacity to encode all patterns from a large pretraining dataset. Similar to LLMs, we use stacked transformer layers (self-attention and feedforward layers) as the main building blocks for the TimesFM model. In the context of time-series forecasting, we treat a patch (a group of contiguous time-points) as a token, an approach popularized by a recent long-horizon forecasting work. The task then is to forecast the (i+1)-th patch of time-points given the i-th output at the end of the stacked transformer layers. However, there are several key differences from language models.
Firstly, we need a multilayer perceptron block with residual connections to convert a patch of time-series into a token that can be input to the transformer layers along with positional encodings (PE). For that, we use a residual block similar to our prior work in long-horizon forecasting. Secondly, at the other end, an output token from the stacked transformer can be used to predict a longer length of subsequent time-points than the input patch length, i.e., the output patch length can be larger than the input patch length. Consider a time-series of length 512 time-points being used to train a TimesFM model with input patch length 32 and output patch length 128. During training, the model is simultaneously trained to use the first 32 time-points to forecast the next 128 time-points, the first 64 time-points to forecast time-points 65 to 192, the first 96 time-points to forecast time-points 97 to 224 and so on. During inference, suppose the model is given a new time-series of length 256 and tasked with forecasting the next 256 time-points into the future. The model will first generate the future predictions for time-points 257 to 384, then condition on the initial 256 length input plus the generated output to generate time-points 385 to 512. On the other hand, if in our model the output patch length was equal to the input patch length of 32, then for the same task we would have to go through eight generation steps instead of just the two above. This increases the chances of more errors accumulating and therefore, in practice, we see that a longer output patch length yields better performance for long-horizon forecasting.
TimesFM architecture.
Pretraining data
Synthetic data helps with the basics. Meaningful synthetic time-series data can be generated using statistical models or physical simulations. These basic temporal patterns can teach the model the grammar of time series forecasting. Real-world data adds real-world flavor.
We comb through available public time series datasets, and selectively put together a large corpus of 100 billion time-points. Among these datasets there are Google Trends and Wikipedia Pageviews, which track what people are interested in, and that nicely mirrors trends and patterns in many other real-world time series. This helps TimesFM understand the bigger picture and generalize better when provided with domain-specific contexts not seen during training.
Zero-shot evaluation results
ARIMA, ETS and can match or outperform powerful DL models like DeepAR, PatchTST that have been explicitly trained on the target time-series. We used the Monash Forecasting Archive to evaluate TimesFM’s out-of-the-box performance. This archive contains tens of thousands of time-series from various domains like traffic, weather, and demand forecasting, covering frequencies ranging from a few minutes to yearly data. Following existing literature, we inspect the mean absolute error (MAE) appropriately scaled so that it can be averaged across the datasets. We see that zero-shot (ZS) TimesFM is better than most supervised approaches, including recent deep learning models. We also compare TimesFM to GPT-3.5 for forecasting using a specific prompting technique proposed by llmtime(ZS). We demonstrate that TimesFM performs better than llmtime(ZS) despite being orders of magnitude smaller.
Scaled MAE (the lower the better) of TimesFM(ZS) against other supervised and zero-shot approaches on Monash datasets.
PatchTST (and other long-horizon forecasting baselines). In the next figure, we plot the MAE on ETT datasets for the task of predicting 96 and 192 time-points into the future. The metric has been calculated on the last test window of each dataset (as done by the llmtime paper). We see that TimesFM not only surpasses the performance of llmtime(ZS) but also matches that of the supervised PatchTST model explicitly trained on the respective datasets.
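The horizons evaluated here (96 and 192 time-points) connect back to the output patch length discussed earlier: each autoregressive step emits one output patch, so the number of generation steps is the horizon divided by the output patch length, rounded up. The following is a minimal sketch of that arithmetic, not TimesFM code. The function names `decoding_steps` and `training_targets` are hypothetical, the patch lengths 32 and 128 come from the example in this post, and the sketch assumes that only full-length training targets are used.

```python
import math


def decoding_steps(horizon: int, output_patch_len: int) -> int:
    """Number of autoregressive steps needed to cover `horizon` time-points."""
    return math.ceil(horizon / output_patch_len)


def training_targets(series_len: int, input_patch_len: int, output_patch_len: int):
    """Yield (context_end, target_start, target_end), 1-indexed and inclusive.

    After every complete input patch, the model is asked to forecast the next
    `output_patch_len` points. Assumption: targets that would run past the end
    of the series are skipped rather than truncated.
    """
    context_end = input_patch_len
    while context_end + output_patch_len <= series_len:
        yield context_end, context_end + 1, context_end + output_patch_len
        context_end += input_patch_len


# Example from the post: context of 256 points, horizon of 256 points.
print(decoding_steps(256, 128))  # 2 steps with output patch length 128
print(decoding_steps(256, 32))   # 8 steps with output patch length 32

# First three training pairs for a series of length 512.
for triple in list(training_targets(512, 32, 128))[:3]:
    print(triple)  # (32, 33, 160), then (64, 65, 192), then (96, 97, 224)
```

With an output patch length of 128, the 96-point horizon needs a single step and the 192-point horizon needs two; with an output patch length of 32, the same horizons would need three and six steps.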
Last window MAE (the lower the better) of TimesFM(ZS) against llmtime(ZS) and long-horizon forecasting baselines on ETT datasets. Conclusion Acknowledgements This work is the result of a collaboration between several individuals across Google Research and Google Cloud, including (in alphabetical order): Abhimanyu Das, Weihao Kong, Andrew Leach, Mike Lawrence, Alex Martin, Rajat Sen, Yang Yang, Skander Hannachi, Ivan Kuznetsov and Yichen Zhou.
Posted by Rishabh Tiwari, Pre-doctoral Researcher, and Pradeep Shenoy, Research Scientist, Google Research Machine learning models in the real world are often trained on limited data that may contain unintended statistical biases. For example, in the CELEBA celebrity image dataset, a disproportionate number of female celebrities have blond hair, leading to classifiers incorrectly predicting “blond” as the hair color for most female faces — here, gender is a spurious feature for predicting hair color. Such unfair biases could have significant consequences in critical applications such as medical diagnosis. Surprisingly, recent work has also discovered an inherent tendency of deep networks to amplify such statistical biases, through the so-called simplicity bias of deep learning. This bias is the tendency of deep networks to identify weakly predictive features early in the training, and continue to anchor on these features, failing to identify more complex and potentially more accurate features. With the above in mind, we propose simple and effective fixes to this dual challenge of spurious features and simplicity bias by applying early readouts and feature forgetting. First, in “Using Early Readouts to Mediate Featural Bias in Distillation”, we show that making predictions from early layers of a deep network (referred to as “early readouts”) can automatically signal issues with the quality of the learned representations. In particular, these predictions are more often wrong, and more confidently wrong, when the network is relying on spurious features. We use this erroneous confidence to improve outcomes in model distillation, a setting where a larger “teacher” model guides the training of a smaller “student” model. Then in “Overcoming Simplicity Bias in Deep Networks using a Feature Sieve”, we intervene directly on these indicator signals by making the network “forget” the problematic features and consequently look for better, more predictive features. 
This substantially improves the model’s ability to generalize to unseen domains compared to previous approaches. Our AI Principles and our Responsible AI practices guide how we research and develop these advanced applications and help us address the challenges posed by statistical biases.
Animation comparing hypothetical responses from two models trained with and without the feature sieve.
Early readouts for debiasing distillation
We first describe early readouts and their application in debiased distillation, i.e., making sure that the student model inherits the teacher model’s resilience to feature bias through distillation. We start with a standard distillation framework where the student is trained with a mixture of label matching (minimizing the cross-entropy loss between student outputs and the ground-truth labels) and teacher matching (minimizing the KL divergence loss between student and teacher outputs for any given input). Suppose one trains a linear decoder, i.e., a small auxiliary neural network named Aux, on top of an intermediate representation of the student model. We refer to the output of this linear decoder as an early readout of the network representation. Our finding is that early readouts make more errors on instances that contain spurious features, and further, the confidence on those errors is higher than the confidence associated with other errors. This suggests that confidence on errors from early readouts is a fairly strong, automated indicator of the model’s dependence on potentially spurious features.
Illustrating the usage of early readouts (i.e., output from the auxiliary layer) in debiasing distillation. Instances that are confidently mispredicted in the early readouts are upweighted in the distillation loss.
We evaluate on four benchmark datasets (Waterbirds, CelebA, CivilComments, MNLI). Each of these datasets contains groupings of data that share an attribute potentially correlated with the label in a spurious manner.
As an example, the CelebA dataset mentioned above includes groups such as {blond male, blond female, non-blond male, non-blond female}, with models typically performing the worst on the {non-blond female} group when predicting hair color. Thus, a measure of model performance is its worst group accuracy, i.e., the lowest accuracy among all known groups present in the dataset. We improved the worst group accuracy of student models on all datasets; moreover, we also improved overall accuracy in three of the four datasets, showing that our improvement on any one group does not come at the expense of accuracy on other groups. More details are available in our paper. Comparison of Worst Group Accuracies of different distillation techniques relative to that of the Teacher model. Our method outperforms other methods on all datasets. Overcoming simplicity bias with a feature sieve feature learning and generalization. The workflow alternates between identifying problematic features and erasing identified features from the network. Our primary hypothesis is that early features are more prone to simplicity bias, and that by erasing (“sieving”) these features, we allow richer feature representations to be learned. Training workflow with feature sieve. We alternate between identifying problematic features (using training iteration) and erasing them from the network (using forgetting iteration). Identifying simple features: We train the primary model and the readout model (AUX above) in conventional fashion via forward- and back-propagation. Note that feedback from the auxiliary layer does not back-propagate to the main network. This is to force the auxiliary layer to learn from already-available features rather than create or reinforce them in the main network. 
Applying the feature sieve: We aim to erase the identified features in the early layers of the neural network with the use of a novel forgetting loss, L_f, which is simply the cross-entropy between the readout and a uniform distribution over labels. Essentially, all information that leads to nontrivial readouts is erased from the primary network. In this step, the auxiliary network and upper layers of the main network are kept unchanged. We can control specifically how the feature sieve is applied to a given dataset through a small number of configuration parameters. By changing the position and complexity of the auxiliary network, we control the complexity of the identified and erased features. By modifying the mixing of learning and forgetting steps, we control the degree to which the model is challenged to learn more complex features. These choices, which are dataset-dependent, are made via hyperparameter search to maximize validation accuracy, a standard measure of generalization. Since we include “no-forgetting” (i.e., the baseline model) in the search space, we expect to find settings that are at least as good as the baseline. Below we show features learned by the baseline model (middle row) and our model (bottom row) on two benchmark datasets: biased activity recognition (BAR) and animal categorization (NICO). Feature importance was estimated using post-hoc gradient-based importance scoring (GRAD-CAM), with the orange-red end of the spectrum indicating high importance, while green-blue indicates low importance. As shown below, our trained models focus on the primary object of interest, whereas the baseline model tends to focus on background features that are simpler and spuriously correlated with the label.
Feature importance scoring using GRAD-CAM on activity recognition (BAR) and animal categorization (NICO) generalization benchmarks.
Our approach (last row) focuses on the relevant objects in the image, whereas the baseline (ERM; middle row) relies on background features that are spuriously correlated with the label. BAR, CelebA Hair, NICO and ImagenetA, by margins up to 11% (see figure below). More details are available in our paper. Our feature sieve method improves accuracy by significant margins relative to the nearest baseline for a range of feature generalization benchmark datasets. Conclusion Acknowledgements The work on applying early readouts to debiasing distillation was conducted in collaboration with our academic partners Durga Sivasubramanian, Anmol Reddy and Prof. Ganesh Ramakrishnan at IIT Bombay. We extend our sincere gratitude to Praneeth Netrapalli and Anshul Nasery for their feedback and recommendations. We are also grateful to Nishant Jain, Shreyas Havaldar, Rachit Bansal, Kartikeya Badola, Amandeep Kaur and the whole cohort of pre-doctoral researchers at Google Research India for taking part in research discussions. Special thanks to Tom Small for creating the animation used in this post.
Posted by Yang Zhao, Senior Software Engineer, and Tingbo Hou, Senior Staff Software Engineer, Core ML
Text-to-image diffusion models have shown exceptional capabilities in generating high-quality images from text prompts. However, leading models feature billions of parameters and are consequently expensive to run, requiring powerful desktops or servers (e.g., Stable Diffusion, DALL·E, and Imagen). While recent advancements in inference solutions on Android via MediaPipe and iOS via Core ML have been made in the past year, rapid (sub-second) text-to-image generation on mobile devices has remained out of reach. To that end, in “MobileDiffusion: Subsecond Text-to-Image Generation on Mobile Devices”, we introduce a novel approach with the potential for rapid text-to-image generation on-device. MobileDiffusion is an efficient latent diffusion model specifically designed for mobile devices. We also adopt DiffusionGAN to achieve one-step sampling during inference, which fine-tunes a pre-trained diffusion model while leveraging a GAN to model the denoising step. We have tested MobileDiffusion on iOS and Android premium devices, and it can run in half a second to generate a 512x512 high-quality image. Its comparatively small model size of just 520M parameters makes it uniquely suited for mobile deployment.
Rapid text-to-image generation on-device.
Background
First, diffusion models require iterative denoising to generate images, necessitating multiple evaluations of the model. Second, the complexity of the network architecture in text-to-image diffusion models involves a substantial number of parameters, regularly reaching into the billions and resulting in computationally expensive evaluations. As a result, despite the potential benefits of deploying generative models on mobile devices, such as enhancing user experience and addressing emerging privacy concerns, it remains relatively unexplored within the current literature.
The optimization of inference efficiency in text-to-image diffusion models has been an active research area. Previous studies predominantly concentrate on addressing the first challenge, seeking to reduce the number of function evaluations (NFEs). Leveraging advanced numerical solvers (e.g., DPM) or distillation techniques (e.g., progressive distillation, consistency distillation), the number of necessary sampling steps has been significantly reduced from several hundred to single digits. Some recent techniques, like DiffusionGAN and Adversarial Diffusion Distillation, even reduce it to a single necessary step. However, on mobile devices, even a small number of evaluation steps can be slow due to the complexity of the model architecture. Thus far, the architectural efficiency of text-to-image diffusion models has received comparatively less attention. A handful of earlier works briefly touch upon this matter, involving the removal of redundant neural network blocks (e.g., SnapFusion). However, these efforts lack a comprehensive analysis of each component within the model architecture, thereby falling short of providing a holistic guide for designing highly efficient architectures.
MobileDiffusion UNet architecture. We present a comprehensive guide for crafting highly efficient text-to-image diffusion models, culminating in MobileDiffusion. The design of MobileDiffusion follows that of latent diffusion models. It contains three components: a text encoder, a diffusion UNet, and an image decoder. For the text encoder, we use CLIP-ViT/L14, which is a small model (125M parameters) suitable for mobile. We then turn our focus to the diffusion UNet and image decoder.
Diffusion UNet UViT architecture, which places more transformer blocks at the bottleneck of the UNet. This design choice is motivated by the fact that the attention computation is less resource-intensive at the bottleneck due to its lower dimensionality.
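To see why attention is cheaper at the bottleneck, recall that self-attention over N tokens costs on the order of N² · d operations, and in a UNet N is the number of spatial positions at a given level. A rough sketch of that scaling follows. The three resolutions and channel widths are illustrative assumptions, not MobileDiffusion's actual configuration, and the hypothetical helper `self_attention_flops` counts only the two attention matrix products (QK^T and the weighted sum over V).

```python
def self_attention_flops(height: int, width: int, channels: int) -> int:
    """Approximate FLOPs of the two attention matmuls (QK^T and attn @ V).

    Each matmul needs N * N * C multiply-adds with N = height * width tokens,
    and a multiply-add is counted as 2 FLOPs. Projections and softmax are
    ignored, so this is a lower bound on the cost of a full attention layer.
    """
    n_tokens = height * width
    return 2 * (2 * n_tokens * n_tokens * channels)


# Hypothetical UNet levels as (height, width, channels); illustrative only.
levels = [(64, 64, 320), (32, 32, 640), (16, 16, 1280)]

bottleneck = self_attention_flops(*levels[-1])
for h, w, c in levels:
    flops = self_attention_flops(h, w, c)
    print(f"{h}x{w}, C={c}: {flops / 1e9:.2f} GFLOPs, "
          f"{flops // bottleneck}x the bottleneck cost")
```

Under these assumptions, halving the spatial resolution cuts the token count by 4× and the N² term by 16×, so even with twice the channels each level down is 8× cheaper, and attention at 64×64 costs 64× as much as at the 16×16 bottleneck.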
Our UNet architecture incorporates more transformers in the middle, and skips self-attention (SA) layers at higher resolutions.
ResNet blocks, are deployed at each level of the UNet. While these blocks are instrumental for feature extraction and information flow, the associated computational costs, especially at high-resolution levels, can be substantial. One proven approach in this context is separable convolution. We observed that replacing regular convolution layers with lightweight separable convolution layers in the deeper segments of the UNet yields similar performance. In the figure below, we compare the UNets of several diffusion models. Our MobileDiffusion exhibits superior efficiency in terms of FLOPs (floating-point operations) and number of parameters.
Comparison of some diffusion UNets.
Image decoder
variational autoencoder (VAE) to encode an RGB image to an 8-channel latent variable, with 8× smaller spatial size of the image. A latent variable can be decoded to an image and gets 8× larger in size. To further enhance efficiency, we design a lightweight decoder architecture by pruning the original’s width and depth. The resulting lightweight decoder leads to a significant performance boost, with nearly 50% latency improvement and better quality. For more details, please refer to our paper.
VAE reconstruction. Our VAE decoders have better visual quality than SD (Stable Diffusion).
Decoder     #Params (M)   PSNR↑   SSIM↑   LPIPS↓
SD          49.5          26.7    0.76    0.037
Ours        39.3          30.0    0.83    0.032
Ours-Lite   9.8           30.2    0.84    0.032
Quality evaluation of VAE decoders. Our lite decoder is much smaller than SD, with better quality metrics, including peak signal-to-noise ratio (PSNR), structural similarity index measure (SSIM), and Learned Perceptual Image Patch Similarity (LPIPS).
One-step sampling
DiffusionGAN hybrid to achieve one-step sampling. Training DiffusionGAN hybrid models for text-to-image generation encounters several intricacies.
Notably, the discriminator, a classifier distinguishing real data and generated data, must make judgments based on both texture and semantics. Moreover, the cost of training text-to-image models can be extremely high, particularly in the case of GAN-based models, where the discriminator introduces additional parameters. Purely GAN-based text-to-image models (e.g., StyleGAN-T, GigaGAN) confront similar complexities, resulting in highly intricate and expensive training. To overcome these challenges, we use a pre-trained diffusion UNet to initialize the generator and discriminator. This design enables seamless initialization with the pre-trained diffusion model. We postulate that the internal features within the diffusion model contain rich information about the intricate interplay between textual and visual data. This initialization strategy significantly streamlines the training. The figure below illustrates the training procedure. After initialization, a noisy image is sent to the generator for one-step diffusion. The result is evaluated against ground truth with a reconstruction loss, similar to diffusion model training. We then add noise to the output and send it to the discriminator, whose result is evaluated with a GAN loss, effectively adopting the GAN to model a denoising step. By using pre-trained weights to initialize the generator and the discriminator, the training becomes a fine-tuning process, which converges in less than 10K iterations.
Illustration of DiffusionGAN fine-tuning.
Results
Images generated by our MobileDiffusion
Latency measurements (s) on mobile devices.
Conclusion
responsible AI practices.
Acknowledgments
We would like to thank our collaborators and contributors who helped bring MobileDiffusion on-device: Zhisheng Xiao, Yanwu Xu, Jiuqiang Tang, Haolin Jia, Lutz Justen, Daniel Fenner, Ronald Wotzlaw, Jianing Wei, Raman Sarokin, Juhyun Lee, Andrei Kulik, Chuo-Ling Chang, and Matthias Grundmann.
Posted by Manish Gupta, Staff Software Engineer, Google Research
AI-driven technologies are weaving themselves into the fabric of our daily routines, with the potential to enhance our access to knowledge and boost our overall productivity. The backbone of these applications lies in large language models (LLMs). LLMs are memory-intensive and typically require specialized hardware accelerators to efficiently deliver tens of exaflops of computing power. This blog post shows how we can start addressing the computational challenges by utilizing memory more effectively. The bulk of an LLM’s memory and compute are consumed by weights in matrix multiplication operations. Using narrower data types reduces memory consumption. For example, storing weights in the 8-bit integer (i.e., U8 or S8) data type reduces the memory footprint by 4× relative to single-precision (F32) and 2× relative to half-precision (F16) or bfloat16 (BF16). Furthermore, previous work has shown that running an LLM’s matrix multiplications with weights in S8 and input in F16 (preserving higher precision of the user-input) is an effective method for increasing efficiency with acceptable trade-offs in accuracy. This technique is known as weight-only quantization and requires efficient implementation of matrix multiplication with mixed-inputs, e.g., half-precision input multiplied with 8-bit integer weights. Hardware accelerators, including GPUs, support a fixed set of data types, and thus, mixed-input matrix multiplication requires software transformations to map to the hardware operations. To that end, in this blog we focus on mapping mixed-input matrix multiplication onto the NVIDIA Ampere architecture. We present software techniques addressing data type conversion and layout conformance to map mixed-input matrix multiplication efficiently onto hardware-supported data types and layouts.
Our results show that the overhead of additional work in software is minimal and enables performance close to the peak hardware capabilities. The software techniques described here are released in the open-source NVIDIA/CUTLASS repository.
Memory footprint for a 175B-parameter LLM with various data type formats.
The matrix-multiply-accumulate operation
Google’s TPU and NVIDIA’s GPU multiply matrices natively in the hardware by targeting Tensor Cores, which are specialized processing elements to accelerate matrix operations, particularly for AI workloads. In this blog, we focus on NVIDIA Ampere Tensor Cores, which provide the matrix-multiply-accumulate (mma) operation. For the rest of the blog, mma refers to the Ampere Tensor Core operation. The supported data types, shapes, and data layout of the two input matrices (called operands) for the mma operation are fixed in hardware. This means that matrix multiplications with various data types and larger shapes are implemented in the software by tiling the problem onto hardware-supported data types, shapes, and layouts. The Tensor Core mma operation is defined by specifying two input matrices (e.g., A & B, shown below) to produce a result matrix, C. The mma operation natively supports mixed-precision. Mixed-precision Tensor Cores allow the input (A and B) data type to differ from the result (C) data type. In contrast, mixed-input matrix multiplication involves mixing the input data types, and it is not supported by the hardware, so it needs to be implemented in the software.
Tensor Core operation of M-by-N-by-K on input matrix A of M-by-K and matrix B of K-by-N produces output matrix C of M-by-N.
Challenges of mixed-input matrix multiplication
hierarchy of memory, including global memory, shared memory, and registers, which are arranged in order of decreasing capacity but increasing speed. NVIDIA Ampere Tensor Core mma operations consume input matrices from registers.
Furthermore, input and output matrices are required to conform to a layout of data within a group of 32 threads known as a warp. The supported data type and layout within a warp are fixed for an mma operation, so to implement mixed-input multiplication efficiently, it is necessary to solve the challenges of data type conversion and layout conformance in software.
Data type conversion
The mma operation requires two input matrices with the same data type. Thus, mixed-input matrix multiplication, where one of the operands is stored in U8 in global memory and the other in F16, requires a data type conversion from U8 to F16. The conversion will bring two operands to F16, mapping the mixed-input matrix multiplication to hardware-supported mixed-precision Tensor Cores. Given the large number of weights, there are a large number of such operations, and our techniques show how to reduce their latency and improve performance.
Layout conformance
The mma operation also requires the layout of two input matrices, within the registers of a warp, to be conformant with the hardware specification. The layout for the input matrix B of U8 data type in mixed-input matrix multiplication (F16 * U8) needs to conform with the converted F16 data type. This is called layout conformance and needs to be achieved in the software. The figure below shows an mma operation consuming matrix A and matrix B from registers to produce matrix C in registers, distributed across one warp. The thread T0 is highlighted and zoomed in to show that the weight matrix B goes through data type conversion and needs layout conformance to be able to map to the hardware-supported Tensor Core operation.
The mapping of mixed-input (F32 = F16 * U8) operation in software to natively supported warp-level Tensor Cores in hardware (F32 = F16 * F16). (Original figure source Developing CUDA kernels to push Tensor Cores to the Absolute Limit on NVIDIA A100.)
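The data type conversion step described above can be emulated on the CPU to make the data flow concrete. The sketch below is only an illustration in plain Python, not the CUTLASS implementation and not GPU code: weights are stored as U8, converted to F16 just before the multiply, and products are accumulated in F32. The helper names `to_f16`, `to_f32` and `mixed_input_matmul` and the tiny matrices are hypothetical, and rounding after every multiply-add is a simplifying assumption about how accumulation behaves.

```python
import struct


def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision (F16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision (F32) value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def mixed_input_matmul(a_f16, b_u8):
    """Emulate F32 = F16 * U8 as F32 = F16 * F16 after converting B to F16.

    a_f16: M x K nested list of floats that are already valid F16 values.
    b_u8:  K x N nested list of ints in 0..255 (the stored weights).
    """
    # Data type conversion: U8 -> F16. This is exact, because F16 represents
    # every integer from 0 to 2048 without rounding.
    b_f16 = [[to_f16(float(v)) for v in row] for row in b_u8]
    m, k, n = len(a_f16), len(b_f16), len(b_f16[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                # Accumulate in F32 (rounded after each multiply-add).
                acc = to_f32(acc + a_f16[i][p] * b_f16[p][j])
            c[i][j] = acc
    return c


a = [[0.5, 1.5], [2.0, -1.0]]   # F16 activations (M x K)
b = [[255, 0], [2, 128]]        # U8 weights (K x N)
print(mixed_input_matmul(a, b))  # [[130.5, 192.0], [508.0, -128.0]]
```

Because the U8-to-F16 conversion is lossless, any accuracy trade-off in weight-only quantization comes from quantizing the weights to 8 bits in the first place, not from this conversion step.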
Software strategies addressing challenges NumericArrayConvertor from 4xU8 to 2x(2xF16) in 32-bit registers. Narrower bitwidth shared memory loads: In this approach, threads issue narrow bitwidth memory loads moving the U8 data from shared memory to registers. This results in two 32-bit registers, with each register containing 2xF16 values (shown above for the matrix B’s thread T0). The narrower shared memory load achieves layout conformance directly into registers without needing any shuffles; however, it does not utilize the full shared memory bandwidth. Pre-processing in global memory: An alternative strategy involves rearranging the data within the global memory (one level above the shared memory in memory hierarchy), allowing wider shared memory loads. This approach maximizes the shared memory bandwidth utilization and ensures that the data is loaded in a conformant layout directly in the registers. Although the rearrangement process can be executed offline prior to the LLM deployment, ensuring no impact on the application performance, it introduces an additional, non-trivial hardware-specific pre-processing step that requires an extra program to rearrange the data. NVIDIA/FasterTransformer adopts this method to effectively address layout conformance challenges. Optimized software strategies FastNumericArrayConvertor and FragmentShuffler, respectively. FastNumericArrayConvertor operates on 4xU8 in 32-bit registers without unpacking individual 1xU8 values. Furthermore, it uses less expensive arithmetic operations which reduces the number of instructions and increases the speed of the conversion. The conversion sequence for U8-to-F16 is shown below. The operations use packed 32b registers, avoiding explicit unpacking and packing. FastNumericArrayConvertor uses the permute byte to rearrange bytes of 4xU8 into two registers. 
Additionally, FastNumericArrayConvertor does not use expensive integer to floating-point conversion instructions and employs vectorized operations to obtain the packed results in two 32-bit registers containing 2x(2xF16) values. The FastNumericArrayConvertor for U8-to-F16 uses approximately six operations, a 1.6× reduction relative to the approach shown above.
FastNumericArrayConvertor utilizes permute bytes and packed arithmetic, reducing the number of instructions in the data type conversion.
FragmentShuffler handles the layout conformance by shuffling data in a way that allows the use of wider bitwidth load operations, increasing shared memory bandwidth utilization and reducing the total number of operations. NVIDIA Ampere architecture provides a load matrix instruction (ldmatrix). The ldmatrix is a warp-level operation, where 32 threads of a warp move the data from shared memory to registers in the shape and layout that mma matrix A and B consume. The use of ldmatrix reduces the number of load instructions and increases the memory bandwidth utilization. Since the ldmatrix instruction moves U8 data to registers, the layout after the load conforms with the U8*U8 mma operation, and not with the F16*F16 mma operation. We implemented FragmentShuffler to rearrange the data within registers using shuffle (shfl.sync) operations to achieve the layout conformance.
FastNumericArrayConvertor covering data type conversion from U8-to-F16, S8-to-F16, U8-to-BF16, and S8-to-BF16.
Performance results
our method (shown below in blue and red; varying the data types of matrix A and B) and two mixed-precision data types (shown in green) on an NVIDIA A100 SXM chip. The performance results are shown in FLOPS (higher is better). Notably, the first eight matrix multiplications require additional operations relative to the last two, because the mixed-precision variants directly target hardware-accelerated Tensor Core operations and do not need data type conversion and layout conformance.
Even so, our approach demonstrates mixed-input matrix multiplication performance only slightly below or on par with mixed-precision.

Mixed-input matrix multiplication performance on an NVIDIA A100 40GB SXM4 chip for a compute-bound matrix problem shape m=3456, n=4096, k=2048.

Acknowledgements

We would like to mention several folks who have contributed through technical brainstorming and improving the blog post, including Quentin Colombet, Jacques Pienaar, Allie Culp, Calin Cascaval, Ashish Gondimalla, Matt Walsh, Marek Kolodziej, and Aman Bhatia. We would like to thank our NVIDIA partners Rawn Henry, Pradeep Ramani, Vijay Thakkar, Haicheng Wu, Andrew Kerr, Matthew Nicely, and Vartika Singh.
Posted by Ameya Velingker, Research Scientist, Google Research, and Balaji Venkatachalam, Software Engineer, Google

Graphs, in which objects and their relations are represented as nodes (or vertices) and edges (or links) between pairs of nodes, are ubiquitous in computing and machine learning (ML). For example, social networks, road networks, and molecular structure and interactions are all domains in which underlying datasets have a natural graph structure. ML can be used to learn the properties of nodes, edges, or entire graphs.

A common approach to learning on graphs is to use graph neural networks (GNNs), which operate on graph data by applying an optimizable transformation to node, edge, and global attributes. The most typical class of GNNs operates via a message-passing framework, whereby each layer aggregates the representation of a node with those of its immediate neighbors.

Recently, graph transformer models have emerged as a popular alternative to message-passing GNNs. These models build on the success of Transformer architectures in natural language processing (NLP), adapting them to graph-structured data. The attention mechanism in graph transformers can be modeled by an interaction graph, in which edges represent pairs of nodes that attend to each other. Unlike message-passing architectures, graph transformers have an interaction graph that is separate from the input graph. The typical interaction graph is a complete graph, which signifies a full attention mechanism that models direct interactions between all pairs of nodes. However, this creates quadratic computational and memory bottlenecks that limit the applicability of graph transformers to datasets of small graphs with at most a few thousand nodes. Making graph transformers scalable has been considered one of the most important research directions in the field (see the first open problem here).

A natural remedy is to use a sparse interaction graph with fewer edges.
Many sparse and efficient transformers have been proposed to eliminate the quadratic bottleneck for sequences; however, they do not generally extend to graphs in a principled manner.

In “Exphormer: Sparse Transformers for Graphs”, presented at ICML 2023, we address the scalability challenge by introducing a sparse attention framework for transformers that is designed specifically for graph data. The Exphormer framework makes use of expander graphs, a powerful tool from spectral graph theory, and is able to achieve strong empirical results on a wide variety of datasets. Our implementation of Exphormer is now available on GitHub.

Expander graphs

Exphormer relies on expander graphs, which are sparse yet well-connected graphs with two useful properties: 1) the matrix representation of the graph has linear-algebraic properties similar to those of a complete graph, and 2) they exhibit rapid mixing of random walks, i.e., a small number of steps in a random walk from any starting node is enough to ensure convergence to a “stable” distribution on the nodes of the graph. Expanders have found applications in diverse areas, such as algorithms, pseudorandomness, complexity theory, and error-correcting codes.

A common class of expander graphs is d-regular expanders, in which there are d edges from every node (i.e., every node has degree d). The quality of an expander graph is measured by its spectral gap, an algebraic property of its adjacency matrix (a matrix representation of the graph in which rows and columns are indexed by nodes and entries indicate whether pairs of nodes are connected by an edge). Graphs that maximize the spectral gap are known as Ramanujan graphs; they achieve a gap of d - 2√(d-1), which is essentially the best possible among d-regular graphs. A number of deterministic and randomized constructions of Ramanujan graphs have been proposed over the years for various values of d. We use a randomized expander construction of Friedman, which produces near-Ramanujan graphs.
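To make the spectral gap concrete, the sketch below samples a random d-regular multigraph and measures its gap numerically. It assumes the permutation model (a union of d/2 random permutations and their reverses), which is one standard way to sample random regular graphs and is used here purely for illustration; the post does not specify the sampling details, and the helper names `random_regular_multigraph` and `spectral_gap` are hypothetical.

```python
import numpy as np


def random_regular_multigraph(n: int, d: int, seed: int = 0) -> np.ndarray:
    """Adjacency matrix of a random d-regular multigraph on n nodes.

    Illustrative permutation model: each of d/2 random permutations adds the
    edges i -> perm[i] and their reverses. Self-loops and repeated edges may
    occur and are counted with multiplicity.
    """
    if d % 2 != 0:
        raise ValueError("d must be even for the permutation model")
    rng = np.random.default_rng(seed)
    adj = np.zeros((n, n))
    nodes = np.arange(n)
    for _ in range(d // 2):
        perm = rng.permutation(n)
        np.add.at(adj, (nodes, perm), 1.0)  # edge i -> perm[i]
        np.add.at(adj, (perm, nodes), 1.0)  # reverse edge keeps adj symmetric
    return adj


def spectral_gap(adj: np.ndarray) -> float:
    """Largest eigenvalue minus the largest other eigenvalue in absolute value."""
    eig = np.linalg.eigvalsh(adj)  # eigenvalues in ascending order
    top = eig[-1]                  # equals d for a d-regular graph
    second = max(abs(eig[0]), abs(eig[-2]))
    return float(top - second)


if __name__ == "__main__":
    n, d = 200, 6
    adj = random_regular_multigraph(n, d)
    print("measured spectral gap:", spectral_gap(adj))
    print("Ramanujan value d - 2*sqrt(d-1):", d - 2 * np.sqrt(d - 1))
```

Running this for a few seeds and sizes gives a feel for how close a random regular graph comes to the Ramanujan value, while the number of edges stays at n·d/2.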
Expander graphs are at the heart of Exphormer. A good expander is sparse yet exhibits rapid mixing of random walks, making its global connectivity suitable for an interaction graph in a graph transformer model. Exphormer replaces the dense, fully-connected interaction graph of a standard Transformer with edges of a sparse d-regular expander graph. Intuitively, the spectral approximation and mixing properties of an expander graph allow distant nodes to communicate with each other after one stacks multiple attention layers in a graph transformer architecture, even though the nodes may not attend to each other directly. Furthermore, by ensuring that d is constant (independent of the number of nodes), we obtain a linear number of edges in the resulting interaction graph.

Exphormer: Constructing a sparse interaction graph

Exphormer builds an interaction graph by combining three types of edges: (1) edges from the input graph (local attention), (2) edges from a constant-degree expander graph (expander attention), and (3) edges from every node to a small set of virtual nodes (global attention). The resulting graph has good connectivity properties and retains the inductive bias of the input dataset graph while still remaining sparse.

Relation to sparse Transformers for sequences

Exphormer is related to BigBird, a sparse Transformer for sequences, which also builds an interaction graph by combining different components. BigBird also uses virtual nodes, but, unlike Exphormer, it uses window attention and random attention from an Erdős-Rényi random graph model for the remaining components.

Window attention in BigBird looks at the tokens surrounding a token in a sequence; the local neighborhood attention in Exphormer can be viewed as a generalization of window attention to graphs.

The Erdős-Rényi graph on n nodes, G(n, p), which connects every pair of nodes independently with probability p, also functions as an expander graph for suitably high p.
However, a superlinear number of edges (Ω(n log n)) is needed to ensure that an Erdős-Rényi graph is connected, let alone a good expander. On the other hand, the expanders used in Exphormer have only a linear number of edges.

Experimental results

We evaluate Exphormer within the GraphGPS framework [3], which combines both message passing and graph transformers and achieves state-of-the-art performance on a number of datasets. We show that replacing dense attention with Exphormer for the graph attention component in the GraphGPS framework allows one to achieve models with comparable or better performance, often with fewer trainable parameters.

Furthermore, Exphormer notably allows graph transformer architectures to scale well beyond the usual graph size limits mentioned above. Exphormer can scale up to datasets of 10,000+ node graphs, such as the Coauthor dataset, and even beyond to larger graphs such as the well-known ogbn-arxiv dataset, a citation network that consists of 170K nodes and 1.1 million edges.

Results comparing Exphormer to standard GraphGPS on the five Long Range Graph Benchmark datasets. We note that Exphormer achieved state-of-the-art results on four of the five datasets (PascalVOC-SP, COCO-SP, Peptides-Struct, PCQM-Contact) at the time of the paper’s publication.

Model | PascalVOC-SP F1 score ↑ | COCO-SP F1 score ↑ | Peptides-Func AP ↑ | Peptides-Struct MAE ↓ | PCQM-Contact MRR ↑
Standard GraphGPS | 0.375 ± 0.011 | 0.341 ± 0.004 | 0.654 ± 0.004 | 0.250 ± 0.001 | 0.334 ± 0.001
Exphormer (ours) | 0.398 ± 0.004 | 0.346 ± 0.001 | 0.653 ± 0.004 | 0.248 ± 0.001 | 0.364 ± 0.002

Finally, we observe that Exphormer, which creates an overlay graph of small diameter via expanders, exhibits the ability to effectively learn long-range dependencies.
The Long Range Graph Benchmark is a suite of five graph learning datasets designed to measure the ability of models to capture long-range interactions. Results show that Exphormer-based models outperform standard GraphGPS models (which were previously state-of-the-art on four out of five datasets at the time of publication).

Conclusion

See the video from ICML 2023.

Acknowledgements

We thank our research collaborators Hamed Shirzad and Danica J. Sutherland from The University of British Columbia as well as Ali Kemal Sinop from Google Research. Special thanks to Tom Small for creating the animation used in this post.